From e0af30caaee633640c47d13891760c88ece85206 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 27 Mar 2017 20:32:41 -0700 Subject: [PATCH] [LANG/SCHEDULE] Reduction factor, predicate in reduction. --- include/tvm/expr.h | 2 + include/tvm/ir.h | 2 +- include/tvm/schedule.h | 12 + python/tvm/__init__.py | 2 +- python/tvm/api.py | 21 +- python/tvm/build.py | 73 +++- python/tvm/schedule.py | 23 +- src/api/api_ir.cc | 2 +- src/api/api_lang.cc | 6 + src/codegen/codegen_c.cc | 2 +- src/lang/ir.cc | 5 +- src/lang/operation.cc | 6 - src/op/compute_op.cc | 36 +- src/op/op_util.cc | 58 +-- src/schedule/bound.cc | 188 +--------- src/schedule/message_passing.cc | 343 ++++++++++++++++++ src/schedule/message_passing.h | 81 +++++ src/schedule/schedule_dataflow_rewrite.cc | 149 +++++++- src/schedule/schedule_lang.cc | 67 +++- tests/python/integration/test_reduce.py | 40 +- tests/python/unittest/test_lang_schedule.py | 25 ++ .../unittest/test_schedule_bound_inference.py | 17 + 22 files changed, 837 insertions(+), 323 deletions(-) delete mode 100644 src/lang/operation.cc create mode 100644 src/schedule/message_passing.cc create mode 100644 src/schedule/message_passing.h diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 1e5d0e0a94c0..91efe1727593 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -40,6 +40,8 @@ using Halide::Internal::as_const_uint; using Halide::Internal::const_true; using Halide::Internal::const_false; using Halide::Internal::is_no_op; +using Halide::likely; +using Halide::likely_if_innermost; inline Type TVMShapeIndexType() { if (std::is_signed::value) { diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 2672a49b1bda..5fdc6fa21240 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -41,7 +41,7 @@ struct Reduce : public ExprNode { /*! \brief construct expr from op and rdom */ static Expr make(std::string op, Expr src, Array rdom, - Expr condition = make_const(Bool(1), true)); + Expr condition = const_true()); void VisitAttrs(AttrVisitor* v) final { v->Visit("dtype", &type); diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index dff799186217..93b93a62cc2c 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -210,6 +210,18 @@ class Schedule : public NodeRef { * \return The created tensor. */ Tensor cache_write(const Tensor& tensor, const std::string& scope); + /*! + * \brief Factor a reduction axis in tensor's schedule to be an explicit axis. + * This will create a new stage that generates the new tensor with axis + * as the first dimension. The tensor's body will be rewritten as a reduction + * over the factored tensor. + * + * \param tensor The tensor to be factored. + * \param axis The reduction axis in tensor's schedule to be factored. + * \return The created factored tensor. + */ + Tensor rfactor(const Tensor& tensor, + const IterVar& axis); /*! * \brief Normalize the schedule. * This is needed before bound inference.
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 51cb4a179436..75d1a727c783 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -19,4 +19,4 @@ from ._base import TVMError from .api import * -from .build import build +from .build import build, lower diff --git a/python/tvm/api.py b/python/tvm/api.py index 9b55801b1f57..9c44f8f5556e 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -372,7 +372,7 @@ def reduce_axis(dom, name="rv"): return _IterVar(dom, name, 2) -def sum(expr, axis): +def sum(expr, axis, where=None): """Create a sum expression over axis Parameters @@ -382,13 +382,16 @@ def sum(expr, axis): axis : IterVar The reduction IterVar axis + + where : optional, Expr + Filtering predicate of the reduction. """ axis = axis if isinstance(axis, list) else [axis] - x = _make.Reduce("Add", expr, axis) + x = _make.Reduce("Add", expr, axis, where) return x -def min(lhs, rhs=None, axis=None): +def min(lhs, rhs=None, axis=None, where=None): """Create a min expression. Parameters @@ -401,6 +404,9 @@ def min(lhs, rhs=None, axis=None): axis : IterVar, optional The reduction IterVar axis + + where : optional, Expr + Filtering predicate of the reduction. """ if rhs and axis: raise ValueError("Can only take one argument, rhs or axis") @@ -409,11 +415,11 @@ def min(lhs, rhs=None, axis=None): if rhs: return _make.Min(lhs, rhs) axis = axis if isinstance(axis, list) else [axis] - x = _make.Reduce("Min", expr, axis) + x = _make.Reduce("Min", expr, axis, where) return x -def max(lhs, rhs=None, axis=None): +def max(lhs, rhs=None, axis=None, where=None): """Create a max expression. Parameters @@ -426,6 +432,9 @@ def max(lhs, rhs=None, axis=None): axis : IterVar, optional The reduction IterVar axis + + where : optional, Expr + Filtering predicate of the reduction. """ if rhs and axis: raise ValueError("Can only take one argument, rhs or axis") @@ -434,7 +443,7 @@ def max(lhs, rhs=None, axis=None): if rhs: return _make.Max(lhs, rhs) axis = axis if isinstance(axis, list) else [axis] - x = _make.Reduce("Max", expr, axis) + x = _make.Reduce("Max", expr, axis, where) return x diff --git a/python/tvm/build.py b/python/tvm/build.py index 8d5ba26c8570..ec5a0dba1c4f 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -9,16 +9,15 @@ from . import schedule from . import expr from . import ir_pass +from . import collections from . import codegen -def build(sch, +def lower(sch, args, - target, - target_host="stackvm", name="default_function", binds=None, max_auto_unroll_step=8): - """Build a function with arguments as signiture. + """Lowering step before build into target. Parameters ---------- @@ -28,12 +27,6 @@ def build(sch, args : list of Buffer or Tensor or Var The argument lists to the function. - target : str - The target of the compilation. - - target_host : - Host compilation target, if target is device. - name : str The name of result function. @@ -46,10 +39,8 @@ def build(sch, Returns ------- - f : Function, or pair of functions + f : LoweredFunc The result function. - If the function requires host space allocation, - a pair of functions will be returned. 
""" binds = {} if binds is None else binds.copy() arg_list = [] @@ -77,6 +68,62 @@ def build(sch, stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step) stmt = ir_pass.Simplify(stmt) fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0) + return fapi + + + +def build(sch, + args=None, + target="llvm", + target_host="stackvm", + name="default_function", + binds=None, + max_auto_unroll_step=8): + """Build a function with arguments as signiture. + + Parameters + ---------- + sch : tvm.Schedule, or LoweredFunc + The schedule to be builded + + args : list of Buffer or Tensor or Var + The argument lists to the function. + + target : str + The target of the compilation. + + target_host : + Host compilation target, if target is device. + + name : str + The name of result function. + + binds : dict, optional + Dictionary that maps the binding of symbolic buffer to Tensor. + By default, a new buffer is created for each tensor in the argument. + + max_auto_unroll_step: int + Maximum step to perform automatic unrolling + + Returns + ------- + f : Function, or pair of functions + The result function. + """ + if isinstance(sch, schedule.Schedule): + if args is None: + raise ValueError("args must be given for build from schedule") + fapi = lower(sch, args, + name=name, + binds=binds, + max_auto_unroll_step=max_auto_unroll_step) + elif isinstance(sch, collections.LoweredFunc): + if args: + raise ValueError("args must be done when build from LoweredFunc") + fapi = sch + else: + raise ValueError("sch have to be Schedule or LoweredFunc") + fsplits = ir_pass.SplitHostDevice(fapi) fsplits = [x for x in fsplits] for i in range(1, len(fsplits)): diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 813b48b6fc33..dcddafa4c7a6 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -87,6 +87,27 @@ def cache_write(self, tensor, scope): """ return _api_internal._ScheduleCacheWrite(self, tensor, scope) + def rfactor(self, tensor, axis): + """ Factor a reduction axis in tensor's schedule to be an explicit axis. + + This will create a new stage that generated the new tensor with axis + as the first dimension. The tensor's body wil be rewriten as a reduction + over the factored tensor. + + Parameters + ---------- + tensor : Tensor + The tensor to be factored. + axis : IterVar + The reduction axis in the schedule to be factored. + + Returns + ------- + tfactor : Tensor + The created factored tensor. + """ + return _api_internal._ScheduleRFactor(self, tensor, axis) + @register_node class Stage(NodeBase): @@ -114,8 +135,6 @@ def split(self, parent, factor=None, outer=None): The inner variable of iteration. 
""" if outer is not None: - if outer.thread_tag == '': - raise ValueError("split by outer must have special thread_tag") inner = _api_internal._StageSplitByOuter(self, parent, outer, factor) else: if factor is None: diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index af07958000fa..61cf3e365f9f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -89,7 +89,7 @@ TVM_REGISTER_API(_make_Allocate) *ret = Node::make(a, b); \ }) -REGISTER_MAKE3(Reduce); +REGISTER_MAKE4(Reduce); REGISTER_MAKE4(AttrStmt); REGISTER_MAKE2(IntImm); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 14c555fb3fb0..933adc872cc2 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -318,4 +318,10 @@ TVM_REGISTER_API(_ScheduleCacheWrite) .cache_write(args[1], args[2]); }); +TVM_REGISTER_API(_ScheduleRFactor) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Schedule() + .rfactor(args[1], args[2]); + }); + } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 55971fe865d3..b288ab82e46d 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -526,8 +526,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitStmt_(const Store* op) { Type t = op->value.type(); if (t.lanes() == 1) { - this->PrintIndent(); std::string value = this->PrintExpr(op->value); + this->PrintIndent(); this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream); stream << " = " << value << ";\n"; } else { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index edd93dac1e45..55a4d7a0de56 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -28,7 +28,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->print(op->source); p->stream << ", axis=" << op->axis; if (!is_const(op->condition, 1)) { - p->stream << ", condition=" << op->condition; + p->stream << ", where=" << op->condition; } p->stream << ")"; }); @@ -45,6 +45,9 @@ Expr Reduce::make(std::string op, Expr source, CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } + if (!condition.defined()) { + condition = const_true(); + } auto n = std::make_shared(); CHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { diff --git a/src/lang/operation.cc b/src/lang/operation.cc deleted file mode 100644 index f6cdaa72b4f0..000000000000 --- a/src/lang/operation.cc +++ /dev/null @@ -1,6 +0,0 @@ -/*! 
- * Copyright (c) 2016 by Contributors - * \file operation.cc - */ -#include -#include diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 7884e454e8da..e2467bc32fcc 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -10,6 +10,7 @@ #include #include #include "./op_util.h" +#include "../schedule/message_passing.h" namespace tvm { @@ -64,10 +65,7 @@ Tensor compute(Array shape, FCompute fcompute, std::string name) { args.push_back(axis.back()->var); } - op_node->axis = Array(axis); - op_node->body = fcompute(args); - op_node->name = name; - return Operation(op_node).output(0); + return ComputeOpNode::make(name, axis, fcompute(args)).output(0); } Operation ComputeOpNode::make(std::string name, @@ -191,6 +189,9 @@ void MakeReduction(const ComputeOpNode* op, } *init = Provide::make(t->op, t->value_index, init_value, args); *provide = Provide::make(t->op, t->value_index, update_value, args); + if (!is_one(reduce->condition)) { + *provide = IfThenElse::make(reduce->condition, *provide); + } } Stmt MakeProvide(const ComputeOpNode* op, @@ -202,31 +203,6 @@ Stmt MakeProvide(const ComputeOpNode* op, return Provide::make(t->op, t->value_index, op->body, args); } -// message passing to find if IterVar is related to reduction. -void PassDownReduceFlag(const Stage& s, - std::unordered_map* p_state) { - auto& state = *p_state; - for (IterVarRelation rel : s->relations) { - if (rel.as()) { - const SplitNode* s = rel.as(); - int flag = state.at(s->parent); - state[s->outer] = flag; - state[s->inner] = flag; - } else if (rel.as()) { - const FuseNode* s = rel.as(); - int flag_outer = state.at(s->outer); - int flag_inner = state.at(s->inner); - state[s->fused] = flag_outer | flag_inner; - } else if (rel.as()) { - const RebaseNode* s = rel.as(); - int flag = state.at(s->parent); - state[s->rebased] = flag; - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - Stmt Substitute(Stmt s, const std::unordered_map& value_map) { Map temp; @@ -267,7 +243,7 @@ Stmt ComputeOpNode::BuildProvide( update_state[iv] = 1; } // find which iter var is related to reduction and which is related to axis. - PassDownReduceFlag(stage, &update_state); + schedule::PassDownBitMaskOr(stage, &update_state); auto leaf_iter_vars = stage->leaf_iter_vars; std::unordered_map init_value_map; // first first loop that is related to reduction. diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 7aa31405748f..487be17cc80b 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -8,6 +8,7 @@ #include #include #include "./op_util.h" +#include "../schedule/message_passing.h" #include "../arithmetic/compute_expr.h" namespace tvm { @@ -16,61 +17,6 @@ namespace op { using namespace arith; using namespace ir; -/*! - * \brief use message passing to calculate the assignment of each Var inside the loop body. - * \param s The schedule to be used. - * \param dom_map The domain map of each iteration variable's domain - * \param p_state The message passing state - * IterVar->The assignment. 
- */ -void PassUpOffset(const Stage& s, - const Map& dom_map, - std::unordered_map* p_state) { - auto& state = *p_state; - for (size_t i = s->relations.size(); i != 0; --i) { - IterVarRelation rel = s->relations[i - 1]; - if (rel.as()) { - const SplitNode* s = rel.as(); - Expr outer = state.at(s->outer); - Expr inner = state.at(s->inner); - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; - state[s->parent] = inner + outer * factor; - // add min if they exist - if (!is_zero(parent_min)) { - state[s->parent] = state[s->parent] + parent_min; - } - } else if (rel.as()) { - const FuseNode* s = rel.as(); - Expr value = state.at(s->fused); - Expr factor = dom_map.at(s->inner)->extent; - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; - state[s->outer] = value / factor; - state[s->inner] = value % factor; - // add min if they exist - if (!is_zero(outer_min)) { - state[s->outer] = state[s->outer] + outer_min; - } - if (!is_zero(inner_min)) { - state[s->inner] = state[s->inner] + inner_min; - } - } else if (rel.as()) { - const RebaseNode* s = rel.as(); - Expr value = state.at(s->rebased); - Expr parent_min = dom_map.at(s->parent)->min; - // add min if they exist - if (!is_zero(parent_min)) { - state[s->parent] = value + parent_min; - } else { - state[s->parent] = value; - } - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - std::vector > MakeLoopNest(const Stage& stage, const std::unordered_map& dom_map, @@ -166,7 +112,7 @@ MakeLoopNest(const Stage& stage, } } // message passing to get offset of root iter vars. - PassUpOffset(stage, dom_map, &value_map); + schedule::PassUpIndex(stage, dom_map, &value_map); return nest; } diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 39fbc929ffe6..ae6cf678e023 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -3,200 +3,18 @@ * \file bound.cc * \brief The bound inference logic. 
*/ -#include #include -#include #include -#include #include #include #include #include "./graph.h" +#include "./message_passing.h" #include "../runtime/thread_storage_scope.h" namespace tvm { namespace schedule { -using namespace arith; - -// result = ceil((a / b)), both a and b are positive integer -inline Expr DivCeil(Expr a, Expr b) { - return ir::Simplify((a + b - 1) / b); -} - -inline bool prove_equal(Expr lhs, Expr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); -} - -// Downward message passing algorithm on stage schedule s, -// pass the range state down from the root to the leaves -// after this pass, every IterVar in the stage hyper graph will have a range(domain) -void PassDown(const Stage& s, - std::unordered_map* p_state) { - auto& state = *p_state; - // forwar iteration on relations - for (IterVarRelation rel : s->relations) { - if (rel.as()) { - const SplitNode* r = rel.as(); - CHECK(state.count(r->parent)); - CHECK(!state.count(r->inner)); - const Range& range_parent = state.at(r->parent); - if (r->factor.defined()) { - state[r->inner] = Range::make_with_min_extent(0, r->factor); - if (r->outer->dom.defined()) { - state[r->outer] = r->outer->dom; - } else { - if (!state.count(r->outer)) { - state[r->outer] = Range::make_with_min_extent( - 0, DivCeil(range_parent->extent, r->factor)); - } else { - Expr outer_ext = DivCeil(range_parent->extent, r->factor); - Range outer_rng = state.at(r->outer); - bool match = is_zero(outer_rng->min); - if (!prove_equal(outer_ext, outer_rng->extent)) match = false; - CHECK(match) - << r->outer - << "IterVar is used in two places as outer scope," - << " cannot prove their extents are the same " - << outer_ext << " vs " << outer_rng->extent; - } - } - } else { - CHECK(r->outer->dom.defined()); - state[r->outer] = r->outer->dom; - state[r->inner] = Range::make_with_min_extent( - 0, DivCeil(range_parent->extent, r->outer->dom->extent)); - } - } else if (rel.as()) { - const FuseNode* r = rel.as(); - CHECK(state.count(r->outer)); - CHECK(state.count(r->inner)); - const Range& range_outer = state.at(r->outer); - const Range& range_inner = state.at(r->inner); - state[r->fused] = Range::make_with_min_extent( - 0, range_outer->extent * range_inner->extent); - } else if (rel.as()) { - const RebaseNode* r = rel.as(); - CHECK(state.count(r->parent)); - state[r->rebased] = Range::make_with_min_extent( - 0, state.at(r->parent)->extent); - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - -// upward message passing algorithm -// pass the integer set on each leave loop up to the root -// dom_map is the result of PassDown, it records the domain of each IterVar. -// dom_map can be used to get cached result in reverse construction. -// Implementation of Evaluations and passing. 
-void PassUp(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent) { - if (dom_map.count(s->outer) && - dom_map.count(s->inner) && - dom_map.count(s->parent) && - outer.match_range(dom_map.at(s->outer)) && - inner.match_range(dom_map.at(s->inner))) { - *parent = IntSet::range(dom_map.at(s->parent)); - return; - } - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; - CHECK(outer.defined()); - CHECK(inner.defined()); - CHECK(factor.defined()); - *parent = EvalSet( - s->outer->var * factor + s->inner->var + parent_min, - {{s->outer, outer}, {s->inner, inner}}); -} - -void PassUp(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner) { - CHECK(dom_map.count(s->outer)); - CHECK(dom_map.count(s->inner)); - CHECK(dom_map.count(s->fused)); - - if (fused.match_range(dom_map.at(s->fused))) { - *outer = IntSet::range(dom_map.at(s->outer)); - *inner = IntSet::range(dom_map.at(s->inner)); - return; - } - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; - - if (fused.is_single_point()) { - Expr value = fused.point_value(); - Expr factor = dom_map.at(s->inner)->extent; - Expr v_outer = value / factor; - Expr v_inner = value % factor; - if (!is_zero(outer_min)) v_outer = v_outer + outer_min; - if (!is_zero(inner_min)) v_inner = v_inner + inner_min; - *outer = IntSet::single_point(v_outer); - *inner = IntSet::single_point(v_inner); - } else { - LOG(WARNING) << "use fallback inference rule in fuse"; - // simply use the entire set, this rule can be enhanced. - *outer = IntSet::range(dom_map.at(s->outer)); - *inner = IntSet::range(dom_map.at(s->inner)); - return; - } -} - -void PassUp(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& rebased, - IntSet* parent) { - CHECK(dom_map.count(s->parent)); - if (rebased.match_range(dom_map.at(s->rebased))) { - *parent = IntSet::range(dom_map.at(s->parent)); - return; - } - Expr parent_min = dom_map.at(s->parent)->min; - *parent = EvalSet(s->rebased->var + parent_min, - {{s->rebased, rebased}}); -} - -void PassUp(const Stage& s, - const std::unordered_map& dom_map, - std::unordered_map* p_state) { - auto& state = *p_state; - for (size_t i = s->relations.size(); i != 0; --i) { - IterVarRelation rel = s->relations[i - 1]; - if (rel.as()) { - IntSet parent; - const SplitNode* r = rel.as(); - PassUp(r, dom_map, - state.at(r->outer), state.at(r->inner), - &parent); - state[r->parent] = parent; - } else if (rel.as()) { - IntSet outer, inner; - const FuseNode* r = rel.as(); - PassUp(r, dom_map, - state.at(r->fused), - &outer, &inner); - state[r->outer] = outer; - state[r->inner] = inner; - } else if (rel.as()) { - IntSet parent; - const RebaseNode* r = rel.as(); - PassUp(r, dom_map, - state.at(r->rebased), - &parent); - state[r->parent] = parent; - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - // check if scope inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { using runtime::ThreadScope; @@ -285,7 +103,7 @@ void InferRootBound(const Stage& stage, } } // get the bound of the root IterVars given current location. 
- PassUp(parent, *rmap, &up_state); + PassUpDomain(parent, *rmap, &up_state); std::unordered_map dom_map; for (auto iv : parent->op->root_iter_vars()) { @@ -358,7 +176,7 @@ Map InferBound(const Schedule& sch) { const Stage& stage = sch->stages[i - 1]; InferRootBound(stage, ctx, attach_path, &ret); // pass down to get bound of all iter vars. - PassDown(stage, &ret); + PassDownDomain(stage, &ret); // setup outer most threads. for (IterVar iv : stage->outermost_threads) { CHECK(iv->dom.defined()); diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc new file mode 100644 index 000000000000..68d28df2c1dc --- /dev/null +++ b/src/schedule/message_passing.cc @@ -0,0 +1,343 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file message_passing.cc + * \brief The message passing domain. + */ +#include +#include +#include +#include "./message_passing.h" + +namespace tvm { +namespace schedule { + +using namespace arith; + +// result = ceil((a / b)), both a and b are positive integer +inline Expr DivCeil(Expr a, Expr b) { + return ir::Simplify((a + b - 1) / b); +} + +inline bool prove_equal(Expr lhs, Expr rhs) { + return is_zero(ir::Simplify(lhs - rhs)); +} + +void PassDownDomain(const Stage& stage, + std::unordered_map* p_state, + bool allow_missing) { + auto& state = *p_state; + // forwar iteration on relations + for (IterVarRelation rel : stage->relations) { + if (const SplitNode* r = rel.as()) { + if (!state.count(r->parent)) { + CHECK(allow_missing); + continue; + } + CHECK(!state.count(r->inner)); + const Range& range_parent = state.at(r->parent); + if (r->factor.defined()) { + state[r->inner] = Range::make_with_min_extent(0, r->factor); + if (r->outer->dom.defined()) { + state[r->outer] = r->outer->dom; + } else { + if (!state.count(r->outer)) { + state[r->outer] = Range::make_with_min_extent( + 0, DivCeil(range_parent->extent, r->factor)); + } else { + Expr outer_ext = DivCeil(range_parent->extent, r->factor); + Range outer_rng = state.at(r->outer); + bool match = is_zero(outer_rng->min); + if (!prove_equal(outer_ext, outer_rng->extent)) match = false; + CHECK(match) + << r->outer + << "IterVar is used in two places as outer scope," + << " cannot prove their extents are the same " + << outer_ext << " vs " << outer_rng->extent; + } + } + } else { + CHECK(r->outer->dom.defined()); + state[r->outer] = r->outer->dom; + state[r->inner] = Range::make_with_min_extent( + 0, DivCeil(range_parent->extent, r->outer->dom->extent)); + } + } else if (const FuseNode* r = rel.as()) { + if (!state.count(r->outer) || !state.count(r->inner)) { + CHECK(allow_missing); + continue; + } + const Range& range_outer = state.at(r->outer); + const Range& range_inner = state.at(r->inner); + state[r->fused] = Range::make_with_min_extent( + 0, range_outer->extent * range_inner->extent); + } else if (const RebaseNode* r = rel.as()) { + if (!state.count(r->parent)) { + CHECK(allow_missing); + continue; + } + state[r->rebased] = Range::make_with_min_extent( + 0, state.at(r->parent)->extent); + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +void PassUpIndex(const Stage& stage, + const Map& dom_map, + std::unordered_map* p_state, + bool allow_missing) { + auto& state = *p_state; + for (size_t i = stage->relations.size(); i != 0; --i) { + IterVarRelation rel = stage->relations[i - 1]; + if (const SplitNode* s = rel.as()) { + if (!state.count(s->outer) || !state.count(s->inner)) { + CHECK(allow_missing); + continue; + } + Expr outer = state.at(s->outer); + Expr inner = state.at(s->inner); 
+ Expr factor = dom_map.at(s->inner)->extent; + Expr parent_min = dom_map.at(s->parent)->min; + state[s->parent] = inner + outer * factor; + // add min if they exist + if (!is_zero(parent_min)) { + state[s->parent] = state[s->parent] + parent_min; + } + } else if (const FuseNode* s = rel.as()) { + if (!state.count(s->fused)) { + CHECK(allow_missing); + continue; + } + Expr value = state.at(s->fused); + Expr factor = dom_map.at(s->inner)->extent; + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + state[s->outer] = value / factor; + state[s->inner] = value % factor; + // add min if they exist + if (!is_zero(outer_min)) { + state[s->outer] = state[s->outer] + outer_min; + } + if (!is_zero(inner_min)) { + state[s->inner] = state[s->inner] + inner_min; + } + } else if (const RebaseNode* s = rel.as()) { + if (!state.count(s->rebased)) { + CHECK(allow_missing); + continue; + } + Expr value = state.at(s->rebased); + Expr parent_min = dom_map.at(s->parent)->min; + // add min if they exist + if (!is_zero(parent_min)) { + state[s->parent] = value + parent_min; + } else { + state[s->parent] = value; + } + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +// Domain message passing. +void PassUpDomain(const SplitNode* s, + const std::unordered_map& dom_map, + const IntSet& outer, + const IntSet& inner, + IntSet* parent) { + if (dom_map.count(s->outer) && + dom_map.count(s->inner) && + dom_map.count(s->parent) && + outer.match_range(dom_map.at(s->outer)) && + inner.match_range(dom_map.at(s->inner))) { + *parent = IntSet::range(dom_map.at(s->parent)); + return; + } + Expr factor = dom_map.at(s->inner)->extent; + Expr parent_min = dom_map.at(s->parent)->min; + CHECK(outer.defined()); + CHECK(inner.defined()); + CHECK(factor.defined()); + *parent = EvalSet( + s->outer->var * factor + s->inner->var + parent_min, + {{s->outer, outer}, {s->inner, inner}}); +} + +void PassUpDomain(const FuseNode* s, + const std::unordered_map& dom_map, + const IntSet& fused, + IntSet* outer, + IntSet* inner) { + CHECK(dom_map.count(s->outer)); + CHECK(dom_map.count(s->inner)); + CHECK(dom_map.count(s->fused)); + + if (fused.match_range(dom_map.at(s->fused))) { + *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); + return; + } + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + + if (fused.is_single_point()) { + Expr value = fused.point_value(); + Expr factor = dom_map.at(s->inner)->extent; + Expr v_outer = value / factor; + Expr v_inner = value % factor; + if (!is_zero(outer_min)) v_outer = v_outer + outer_min; + if (!is_zero(inner_min)) v_inner = v_inner + inner_min; + *outer = IntSet::single_point(v_outer); + *inner = IntSet::single_point(v_inner); + } else { + LOG(WARNING) << "use fallback inference rule in fuse"; + // simply use the entire set, this rule can be enhanced. 
+ *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); + return; + } +} + +void PassUpDomain(const RebaseNode* s, + const std::unordered_map& dom_map, + const IntSet& rebased, + IntSet* parent) { + CHECK(dom_map.count(s->parent)); + if (rebased.match_range(dom_map.at(s->rebased))) { + *parent = IntSet::range(dom_map.at(s->parent)); + return; + } + Expr parent_min = dom_map.at(s->parent)->min; + *parent = EvalSet(s->rebased->var + parent_min, + {{s->rebased, rebased}}); +} + +void PassUpDomain(const Stage& stage, + const std::unordered_map& dom_map, + std::unordered_map* p_state) { + auto& state = *p_state; + for (size_t i = stage->relations.size(); i != 0; --i) { + IterVarRelation rel = stage->relations[i - 1]; + if (const SplitNode* r = rel.as()) { + IntSet parent; + PassUpDomain(r, dom_map, + state.at(r->outer), state.at(r->inner), + &parent); + state[r->parent] = parent; + } else if (const FuseNode* r = rel.as()) { + IntSet outer, inner; + PassUpDomain(r, dom_map, + state.at(r->fused), + &outer, &inner); + state[r->outer] = outer; + state[r->inner] = inner; + } else if (const RebaseNode* r = rel.as()) { + IntSet parent; + PassUpDomain(r, dom_map, + state.at(r->rebased), + &parent); + state[r->parent] = parent; + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +// Pass up bit mask with or relation. +void PassUpBitMaskOr(const Stage& stage, + std::unordered_map* p_state, + bool allow_missing) { + auto& state = *p_state; + for (size_t i = stage->relations.size(); i != 0; --i) { + IterVarRelation rel = stage->relations[i - 1]; + if (const SplitNode* s = rel.as()) { + if (!state.count(s->inner) && !state.count(s->outer)) { + CHECK(allow_missing); + continue; + } + int res = 0; + if (state.count(s->parent)) res |= state[s->parent]; + if (state.count(s->inner)) res |= state[s->inner]; + if (state.count(s->outer)) res |= state[s->outer]; + state[s->parent] = res; + } else if (const FuseNode* s = rel.as()) { + if (!state.count(s->fused)) { + CHECK(allow_missing); + continue; + } + if (!state.count(s->outer)) { + state[s->outer] = state[s->fused]; + } else { + state[s->outer] |= state[s->fused]; + } + if (!state.count(s->inner)) { + state[s->inner] = state[s->fused]; + } else { + state[s->inner] |= state[s->fused]; + } + } else if (const RebaseNode* s = rel.as()) { + if (!state.count(s->rebased)) { + CHECK(allow_missing); + continue; + } + if (!state.count(s->parent)) { + state[s->parent] = state[s->rebased]; + } else { + state[s->parent] |= state[s->rebased]; + } + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +void PassDownBitMaskOr(const Stage& stage, + std::unordered_map* p_state, + bool allow_missing) { + auto& state = *p_state; + for (IterVarRelation rel : stage->relations) { + if (const SplitNode* s = rel.as()) { + if (!state.count(s->parent)) { + CHECK(allow_missing); + continue; + } + if (!state.count(s->outer)) { + state[s->outer] = state.at(s->parent); + } else { + state[s->outer] |= state.at(s->parent); + } + if (!state.count(s->inner)) { + state[s->inner] = state.at(s->parent); + } else { + state[s->inner] |= state.at(s->parent); + } + } else if (const FuseNode* s = rel.as()) { + if (!state.count(s->outer) && !state.count(s->inner)) { + CHECK(allow_missing); + continue; + } + int res = 0; + if (state.count(s->outer)) res |= state.at(s->outer); + if (state.count(s->inner)) res |= state.at(s->inner); + if (state.count(s->fused)) res |= state.at(s->fused); + state[s->fused] = res; + } else if (const RebaseNode* 
s = rel.as()) { + if (!state.count(s->parent)) { + CHECK(allow_missing); + continue; + } + if (!state.count(s->rebased)) { + state[s->rebased] = state.at(s->parent); + } else { + state[s->rebased] |= state.at(s->parent); + } + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +} // namespace schedule +} // namespace tvm diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h new file mode 100644 index 000000000000..5b7cf9d2400f --- /dev/null +++ b/src/schedule/message_passing.h @@ -0,0 +1,81 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file message_passing.h + * \brief Common utilities to do message passing + * on the schedule hyper graph. + */ +#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_ +#define TVM_SCHEDULE_MESSAGE_PASSING_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace schedule { +/*! + * \brief Downward inference of domain of each IterVar. + * Caller sets the range of the root, then the function + * propagates it towards the leaves. + * + * \param stage The stage to operate on. + * \param p_state The state of the message passing. + * \param allow_missing Whether to allow missing values. + */ +void PassDownDomain( + const Stage& stage, + std::unordered_map* p_state, + bool allow_missing = false); + +/*! + * \brief Upward inference of index of each IterVar, + * given the index assignment of the leaves. + * + * \param stage The stage to operate on. + * \param dom_map The domain map of each iteration variable's domain. + * \param p_state The index state of each IterVar. + * \param allow_missing Whether to allow missing values. + */ +void PassUpIndex(const Stage& stage, + const Map& dom_map, + std::unordered_map* p_state, + bool allow_missing = false); + +/*! + * \brief Upward inference of domain set of each IterVar, + * given the domain assignment of the leaves. + * + * \param stage The stage to operate on. + * \param dom_map The domain map of each iteration variable's maximum domain. + * \param p_state The index state of each IterVar. + */ +void PassUpDomain(const Stage& stage, + const std::unordered_map& dom_map, + std::unordered_map* p_state); + +/*! + * \brief Upward message passing of bitmask with or relation. + * \param stage The stage to operate on. + * \param p_state The index state of each IterVar. + * \param allow_missing Whether to allow missing values. + */ +void PassUpBitMaskOr(const Stage& stage, + std::unordered_map* p_state, + bool allow_missing = false); + +/*! + * \brief Downward message passing of bitmask with or relation. + * \param stage The stage to operate on. + * \param p_state The index state of each IterVar. + * \param allow_missing Whether to allow missing values. + */ +void PassDownBitMaskOr(const Stage& stage, + std::unordered_map* p_state, + bool allow_missing = false); +} // namespace schedule +} // namespace tvm +#endif // TVM_SCHEDULE_MESSAGE_PASSING_H_ diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index a6c193e876dd..b577f0a431a7 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -7,6 +7,7 @@ #include #include #include +#include "./message_passing.h" namespace tvm { @@ -139,7 +140,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, return cache_tensor; } - void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; std::unordered_map attach_mark; @@ -244,4 +244,151 @@ void Schedule::normalize() { InjectInline(*this); } +// Handle reduction factor.
+Tensor Schedule::rfactor(const Tensor& tensor, + const IterVar& axis) { + using ir::Reduce; + CHECK_EQ(axis->iter_type, kCommReduce) + << "Can only factor reduction axis"; + Stage reduce_stage = operator[](tensor->op); + const ComputeOpNode* compute_op = reduce_stage->op.as(); + CHECK(compute_op) << "Can only factor ComputeOp"; + ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); + { + size_t axis_pos = FindNodeRef(leaf_vars, axis); + CHECK_NE(axis_pos, leaf_vars->data.size()) + << "Cannot find IterVar " << axis << " in leaf iter vars"; + } + // Find touched reduction axis. + std::unordered_map touch_map; + touch_map[axis] = 1; + schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true); + schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true); + // Verify normal axis are not touched. + for (IterVar iv : compute_op->axis) { + CHECK(!touch_map.count(iv)) + << "Factor axis touches normal axis."; + } + // Get the replace index + std::unordered_map dom_map; + std::unordered_map value_map; + for (IterVar iv : compute_op->reduce_axis) { + if (touch_map.count(iv)) dom_map[iv] = iv->dom; + } + schedule::PassDownDomain(reduce_stage, &dom_map, true); + for (IterVar iv : reduce_stage->leaf_iter_vars) { + if (touch_map.count(iv)) { + Range dom = dom_map.at(iv); + if (is_one(dom->extent)) { + value_map[iv] = dom->min; + } else { + value_map[iv] = iv->var; + } + } + } + schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true); + // Get the factored op node. + auto n = std::make_shared(); + n->name = compute_op->name + ".rf"; + { + // axis relacement. + auto iv_node = std::make_shared(); + iv_node->dom = dom_map.at(axis); + CHECK(is_zero(iv_node->dom->min)) + << "Can only factor reduction domain starting from 0"; + iv_node->var = axis->var; + iv_node->iter_type = kDataPar; + n->axis.push_back(IterVar(iv_node)); + + for (IterVar iv : compute_op->axis) { + n->axis.push_back(iv); + } + } + // predicate generation, copy not touched axis. + std::unordered_map vsub; + Expr predicate; + for (IterVar iv : compute_op->reduce_axis) { + if (!touch_map.count(iv)) { + n->reduce_axis.push_back(iv); + } else { + CHECK(value_map.count(iv)); + Expr index = value_map.at(iv); + vsub[iv->var.get()] = index; + if (!index.same_as(iv->var)) { + Expr cond = (index < dom_map.at(iv)->extent); + if (predicate.defined()) { + predicate = predicate && cond; + } else { + predicate = cond; + } + } + } + } + // Copy touched axis. + for (IterVar iv : reduce_stage->leaf_iter_vars) { + if (touch_map.count(iv) && !iv.same_as(axis)) { + CHECK_EQ(iv->iter_type, kCommReduce); + auto ncpy = std::make_shared(*iv.operator->()); + ncpy->dom = dom_map.at(iv); + n->reduce_axis.push_back(IterVar(ncpy)); + } + } + const Reduce* reduce = compute_op->body.as(); + CHECK(reduce) << "Can only rfactor non-inline reductions"; + n->body = Reduce::make(reduce->op, + VarReplacer(vsub).Mutate(reduce->source), + n->reduce_axis, + predicate); + // refresh relations, keep the un-touched relations. + Array rels; + for (IterVarRelation rel : reduce_stage->relations) { + bool touched = false; + if (const SplitNode* r = rel.as()) { + if (touch_map.count(r->parent)) touched = true; + } else if (const FuseNode* r = rel.as()) { + if (touch_map.count(r->fused)) touched = true; + } else if (const RebaseNode* r = rel.as()) { + if (touch_map.count(r->parent)) touched = true; + } else { + LOG(FATAL) << "unknown relation type"; + } + if (!touched) { + rels.push_back(rel); + } + } + // initialize the factored stage. 
+ Operation factor_op(n); + ArrayNode* stages = (*this)->stages.CopyOnWrite(); + size_t stage_pos = FindNodeRef(stages, reduce_stage); + Stage factor_stage = Stage(factor_op); + factor_stage->relations = rels; + CHECK_LT(stage_pos, stages->data.size()); + stages->data.insert(stages->data.begin() + stage_pos, + factor_stage.node_); + (*this)->stage_map.Set(factor_op, factor_stage); + // Replace the old reduction. + IterVar repl_red_axis = reduce_axis( + dom_map.at(axis), axis->var->name_hint + ".v"); + Tensor factor_tensor = factor_op.output(0); + Tensor old_tensor = reduce_stage->op.output(0); + Tensor repl_tensor = compute(old_tensor->shape, [&](const Array& i) { + Array indices; + indices.push_back(repl_red_axis->var); + for (Var v : i) { + indices.push_back(v); + } + return Reduce::make( + reduce->op, factor_tensor(indices), {repl_red_axis}, const_true()); + }, old_tensor->op->name + ".repl"); + + std::unordered_map vmap; + vmap[old_tensor] = repl_tensor; + ReplaceDataFlow((*this)->stages, &vmap); + // revamp the reduction stage. + reduce_stage->op = repl_tensor->op; + reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars(); + reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars; + reduce_stage->relations = Array(); + return factor_tensor; +} } // namespace tvm diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index bbcd832e269e..318e9b057a0d 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -43,11 +43,18 @@ void CheckSplit(StageNode* self, IterVar parent, IterVar outer) { << "Cannot split on axis[0] of scan update"; } if (outer.defined()) { - CHECK_EQ(outer->iter_type, kThreadIndex) - << "outer in split have to be ThreadIndex"; - CHECK_EQ(parent->iter_type, kDataPar) - << "Split by by kThreadIndex requires kDataPar IterVar " - << " given " << IterVarType2String(parent->iter_type); + if (outer->iter_type == kThreadIndex) { + CHECK_EQ(parent->iter_type, kDataPar) + << "Split by kThreadIndex requires kDataPar IterVar " + << " given " << IterVarType2String(parent->iter_type); + } else if (outer->iter_type == kCommReduce) { + CHECK_EQ(parent->iter_type, kCommReduce) + << "Split by kCommReduce requires kCommReduce IterVar " + << " given " << IterVarType2String(parent->iter_type); + } else { + LOG(FATAL) << "Cannot take " << IterVarType2String(outer->iter_type) + << " as outer IterVar"; + } } else { CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || @@ -73,18 +80,6 @@ void Split(StageNode* self, IterVar parent, } // namespace -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const StageNode *op, IRPrinter *p) { - p->stream << "stage(" - << op->op - << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const IterVarAttrNode *op, IRPrinter *p) { - p->stream << IterVarType2String(op->iter_type); - }); - Stage::Stage(Operation op) { auto n = std::make_shared(); n->op = op; @@ -374,4 +369,42 @@ TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(RebaseNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); +// Printer +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const StageNode *op, IRPrinter *p) { + p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; +}) +.set_dispatch([](const IterVarAttrNode *op, IRPrinter *p) { + p->stream << IterVarType2String(op->iter_type); +}) +.set_dispatch([](const SplitNode *op, IRPrinter *p) { + p->stream << "split(parent="; + p->print(op->parent); + p->stream << ", outer="; + p->print(op->outer); + p->stream << ", 
inner="; + p->print(op->inner); + p->stream << ')'; +}) +.set_dispatch([](const FuseNode *op, IRPrinter *p) { + p->stream << "split("; + p->stream << "outer="; + p->print(op->outer); + p->stream << ", inner="; + p->print(op->inner); + p->stream << ", fused="; + p->print(op->fused); + p->stream << ')'; +}) +.set_dispatch([](const RebaseNode *op, IRPrinter *p) { + p->stream << "rebase("; + p->stream << "parent="; + p->print(op->parent); + p->stream << ", rebased="; + p->print(op->rebased); + p->stream << ')'; +}) +.set_dispatch([](const ScheduleNode *op, IRPrinter *p) { + p->stream << "schedule(" << op << ")"; + }); } // namespace tvm diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 4c341444fd39..726cd3f11ed3 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -7,7 +7,7 @@ def test_sum(): m = tvm.Var('m') A = tvm.placeholder((n, m), name='A') k = tvm.reduce_axis((0, m)) - B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B') + B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B') # schedule s = tvm.Schedule(B.op) # create iter var and assign them tags. @@ -28,14 +28,17 @@ def check_device(device, host="stackvm"): args=[A, B], target=device, target_host=host, name="mysum") + print(fsum.imported_modules[0].get_source()) # launch the kernel. n = 1028 m = 129 a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) fsum(a, b) + res = np.sum(a.asnumpy(), axis=1) + res[:2] = 0 np.testing.assert_allclose( - b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) + b.asnumpy(), res, rtol=1e-4) if tvm.module.enabled("opencl"): tvm.module.init_opencl() @@ -43,5 +46,38 @@ def check_device(device, host="stackvm"): check_device("cuda") check_device("opencl") + +def test_rfactor(): + n = tvm.convert(1027) + A = tvm.placeholder((n,), name='A') + k = tvm.reduce_axis((0, n)) + B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B') + kf = tvm.reduce_axis((0, 4)) + # schedule + s = tvm.Schedule(B.op) + _, ki = s[B].split(k, outer=kf) + BF = s.rfactor(B, kf) + s[BF].parallel(BF.op.axis[0]) + # one line to build the function. + def check_target(target="llvm"): + if not tvm.codegen.enabled(target): + return + ctx = tvm.cpu(0) + fapi = tvm.lower(s, args=[A, B]) + fsum = tvm.build(fapi, + target=target, + name="mysum") + # launch the kernel. 
+ n = 1027 + a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx) + fsum(a, b) + res = np.sum(a.asnumpy(), axis=0) + np.testing.assert_allclose( + b.asnumpy(), res, rtol=1e-4) + + check_target() + if __name__ == "__main__": + test_rfactor() test_sum() diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index b4ca987ed30b..c38dc59a019e 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -91,8 +91,33 @@ def test_vectorize(): assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE +def test_rfactor(): + n = tvm.Var('n') + k1 = tvm.reduce_axis((0, n), name="k1") + k2 = tvm.reduce_axis((0, n), name="k2") + A = tvm.placeholder((n, n, n), name='A') + B = tvm.compute((n, ), lambda i: tvm.sum(A[i, k1, k2], axis=[k1, k2])) + # normal schedule + s = tvm.Schedule(B.op) + BF = s.rfactor(B, k1) + assert(tuple(BF.shape) == (n, n)) + assert(set(BF.op.body.axis) == set([k2])) + assert(s[B].op.body.axis[0].dom.extent == n) + assert(len(s[B].all_iter_vars) == 2) + # schedule with split + s = tvm.Schedule(B.op) + ko, ki = s[B].split(k1, factor=4) + xo, xi = s[B].split(B.op.axis[0], factor=8) + BF = s.rfactor(B, ki) + assert(BF.shape[0].value == 4) + assert(BF.shape[1] == n) + assert(BF.op.body.axis[0] == k2) + assert(BF.op.body.axis[1].var == ko.var) + assert(s[B].op.body.axis[0].dom.extent.value == 4) + if __name__ == "__main__": + test_rfactor() test_schedule_create() test_reorder() test_tile() diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index e60edd660a8a..15b18ec2f60b 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -100,7 +100,24 @@ def computeB(ii, jj): assert(bounds[A.op.axis[0]].extent.value == 3) assert(bounds[A.op.axis[1]].extent.value == 3) +def test_bound_rfactor(): + n = tvm.Var('n') + A = tvm.placeholder((n,), name='A') + k = tvm.reduce_axis((0, n)) + B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B') + kf = tvm.reduce_axis((0, 4)) + # schedule + s = tvm.Schedule(B.op) + _, ki = s[B].split(k, outer=kf) + BF = s.rfactor(B, kf) + s.normalize() + bounds = tvm.schedule.InferBound(s) + assert(bounds[BF.op.axis[0]].extent.value == 4) + assert(bounds[BF.op.axis[1]].extent.value == 1) + + if __name__ == "__main__": + test_bound_rfactor() test_bound_blur() test_bound_conv1d() test_bound_scan()
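The two user-facing features this patch adds are the where predicate on reductions and Schedule.rfactor. Below is a minimal usage sketch distilled from the tests in this commit (test_reduce.py and test_lang_schedule.py); the tensor names, shapes, the "llvm" target, and the split factor of 16 are illustrative choices, not part of the patch.

    import tvm

    # Part 1: predicated reduction. Rows 0 and 1 are excluded from the sum
    # because the new `where` predicate is false there.
    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), name='k')
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i > 1)), name='B')

    # Part 2: rfactor. Split the reduction axis of a plain sum, then factor the
    # inner part out into an explicit first axis of a new intermediate stage.
    X = tvm.placeholder((m,), name='X')
    j = tvm.reduce_axis((0, m), name='j')
    C = tvm.compute((1,), lambda i: tvm.sum(X[j], axis=j), name='C')
    s = tvm.Schedule(C.op)
    jo, ji = s[C].split(j, factor=16)
    CF = s.rfactor(C, ji)            # CF has shape (16, 1); C now reduces over CF
    s[CF].parallel(CF.op.axis[0])    # the factored axis is an ordinary data-parallel axis
    fapi = tvm.lower(s, args=[X, C], name="mysum")
    fsum = tvm.build(fapi, target="llvm", name="mysum")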