From da6a07cdfa379f22f32fa535f8ddddedb1fc9f75 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 4 May 2018 16:54:07 +0800 Subject: [PATCH 1/7] [TensorOp] Interface of TensorOp. --- include/tvm/expr.h | 2 + include/tvm/operation.h | 68 +++- python/tvm/api.py | 86 ++++- python/tvm/tensor.py | 7 + src/api/api_lang.cc | 11 + src/op/compute_op.cc | 35 ++ src/op/compute_op.h | 17 +- src/op/tensor_op.cc | 385 ++++++++++++++++++++++ src/op/tensorize.cc | 45 --- src/schedule/schedule_dataflow_rewrite.cc | 161 ++++++++- tests/python/unittest/test_lang_tensor.py | 32 ++ 11 files changed, 793 insertions(+), 56 deletions(-) create mode 100644 src/op/tensor_op.cc diff --git a/include/tvm/expr.h b/include/tvm/expr.h index a199d656caf8..8bb432ab6641 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -118,6 +118,8 @@ class Range : public HalideIR::IR::Range { TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); }; +using Region = Array; + /*! * \brief Type of iteration variable. * Each IterVar have a specific type. diff --git a/include/tvm/operation.h b/include/tvm/operation.h index c11242c0a55d..5eeac93dfc94 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode { } /*! * \return The list of iteration variable at root - * \note root_iter_vars dedides the shape of the outputs. + * \note root_iter_vars decides the shape of the outputs. */ virtual Array root_iter_vars() const = 0; /*! @@ -182,6 +182,70 @@ class PlaceholderOpNode : public OperationNode { TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); }; +class TensorOpNode : public OperationNode { + public: + Array axis; + + Array tensor_axis; + + Array reduce_axis; + + Array inputs; + + Array input_regions; + + TensorIntrin intrin; + + /*! \brief constructor */ + TensorOpNode() {} + + // override functions + int num_outputs() const final; + Array root_iter_vars() const final; + Type output_dtype(size_t i) const final; + Array output_shape(size_t i) const final; + Array InputTensors() const final; + Operation ReplaceInputs( + const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs( + const Operation& self, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound( + const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize( + const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide( + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("axis", &axis); + v->Visit("tensor_axis", &tensor_axis); + v->Visit("reduce_axis", &reduce_axis); + v->Visit("inputs", &inputs); + } + + static Operation make(std::string name, + std::string tag, + Array axis, + Array reduce_axis, + Array inputs, + Array input_regions, + TensorIntrin intrin); + + static constexpr const char* _type_key = "TensorOp"; + TVM_DECLARE_NODE_TYPE_INFO(TensorOpNode, OperationNode); +}; + /*! * \brief A Compute op that compute a tensor on certain domain. */ @@ -326,7 +390,7 @@ class ExternOpNode : public OperationNode { public: /*! \brief The input tensors */ Array inputs; - /*! \brief Symbolic placeholder representationinputs */ + /*! \brief Symbolic placeholder representation of inputs */ Array input_placeholders; /*! 
\brief Symbolic placeholder representation of outputs */ Array output_placeholders; diff --git a/python/tvm/api.py b/python/tvm/api.py index 223e73eeb596..1c1dbfaa1532 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs from numbers import Integral as _Integral +from collections import namedtuple from ._ffi.base import string_types from ._ffi.node import register_node, NodeBase @@ -14,6 +15,7 @@ from . import _api_internal from . import make as _make from . import expr as _expr +from . import stmt as _stmt from . import tensor as _tensor from . import schedule as _schedule from . import container as _container @@ -335,6 +337,86 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr return res[0] if len(res) == 1 else res +def tensor_op(out_dims, + in_dims, # pylint: disable=unused-argument + finputs, + intrin, + raxis=None, + name='tensor_op', + tag=""): + """Construct new tensors with intrinsic. + + Parameters + ---------- + out_dims: tuple + The dimensions out of the tensorized region, which can be + scheduled through `reorder`, `split`. + + in_dims: tuple + The dimensions inside of the tensorized region, which cannot + be manipulated. + + finputs: lambda function of out_dims -> list of TensorSlice + Specifies involved regions of input tensors. + + tensor_intrin : TensorIntrin + The tensor intrinsic used for computation. + + raxis : IterVar + An iteration variable representing the value. + + name: str, optional + The name hint of the tensor + + tag: str, optional + Additonal tag information about the compute. + """ + def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var + else: + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + + if _tag.TagScope.current is not None: + if tag != "": + raise ValueError("nested tag is not allowed for now") + tag = _tag.TagScope.current.tag + + code = finputs.__code__ + if finputs.__code__.co_argcount == 0: + arg_names = ["i%d" % i for i in range(ndim)] + else: + arg_names = code.co_varnames[:code.co_argcount] + + if len(out_dims) != len(arg_names): + raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) + + out_var = [_IterVar((0, extent), arg_name, 0) for arg_name, extent in zip(arg_names, out_dims)] + if isinstance(raxis, _schedule.IterVar): + raxis = [raxis] + if raxis is None: + raxis = [] + tensor_regions = finputs(*[v.var for v in out_var]) + + op = _api_internal._TensorOp(name, + tag, + out_var, + raxis, + [x.tensor for x in tensor_regions], + [_get_region(x) for x in tensor_regions], + intrin) + # only support single output + return op.output(0) + + def extern(shape, inputs, fcompute, @@ -529,14 +611,14 @@ def decl_buffer(shape, dtype = float32 if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - elem_offset = var('%s_elem_offset' % name, shape[0].dtype) + shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + elem_offset = var('%s_elem_offset' % name, shape_dtype) if data is None: data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor) - def _IterVar(dom, name, iter_type, thread_tag=''): """Internal function to create IterVar diff --git 
a/python/tvm/tensor.py b/python/tvm/tensor.py index f0d60f514a37..254e17710d47 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -106,6 +106,7 @@ def name(self): return "%s.v%d" % (op.name, self.value_index) + class Operation(NodeBase): """Represent an operation that generate a tensor""" @@ -168,3 +169,9 @@ def scan_axis(self): class ExternOp(Operation): """Extern operation.""" pass + + +@register_node +class TensorOp(Operation): + """Tensor operation.""" + pass diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8ca49f19baec..8cc34dd74448 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -278,6 +278,17 @@ TVM_REGISTER_API("_ScanOp") args[7]); }); +TVM_REGISTER_API("_TensorOp") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TensorOpNode::make(args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6]); + }); + TVM_REGISTER_API("_ExternOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ExternOpNode::make(args[0], diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 6100c957e473..0c40882c0be2 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -13,6 +13,7 @@ #include "compute_op.h" #include "op_util.h" #include "../schedule/message_passing.h" +#include "../arithmetic/compute_expr.h" namespace tvm { @@ -542,4 +543,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) { v.Run(); } +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update) { + Array conds; + std::unordered_set banned; + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + IterVar iv = stage->leaf_iter_vars[i]; + auto iit = stage->iter_var_attrs.find(iv); + if (iit != stage->iter_var_attrs.end()) { + const IterVarAttr& attr = (*iit).second; + if (attr->iter_type == kTensorized) { + break; + } + } + if (iv->iter_type == kCommReduce) { + auto vit = dom_map.find(iv); + CHECK(vit != dom_map.end()); + const Range& vrange = vit->second; + conds.push_back(likely(iv->var > vrange->min)); + banned.insert(iv->var.get()); + } + } + for (const Expr& pred : n.main_predicates) { + if (ir::ExprUseVar(pred, banned)) { + LOG(FATAL) << "Tensorize update transform failed, the condition " + << pred << " has a conflict with the reset condition"; + } + } + + return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), + update, body); +} } // namespace tvm diff --git a/src/op/compute_op.h b/src/op/compute_op.h index 996764c6cdc1..87b0814c1ad9 100644 --- a/src/op/compute_op.h +++ b/src/op/compute_op.h @@ -14,7 +14,7 @@ namespace tvm { // loop nest structure for general compute -// This the the loop nest structured used in compute. +// This the loop nest structured used in compute. // Does not include the loop body. struct ComputeLoopNest { // The common number of loops between init and main @@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop); + +/*! + * \brief Transform the update part when there is no init func in tensorizing + * \param stage The stage for tensorizing. + * \param dom_map The range of each iter var. + * \param n The loop nest structured used in compute. + * \param body The body func in tensorize intrin + * \param update The update func in tensorize intrin + * \return Transformed result. 
+ */ +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update); } // namespace tvm #endif // TVM_OP_COMPUTE_OP_H_ diff --git a/src/op/tensor_op.cc b/src/op/tensor_op.cc new file mode 100644 index 000000000000..13ca38b63529 --- /dev/null +++ b/src/op/tensor_op.cc @@ -0,0 +1,385 @@ +/*! + * Copyright (c) 2017 by Contributors + * \brief Compute Op. + * \file compute_op.cc + */ +#include +#include +#include +#include +#include +#include +#include "./op_util.h" +#include "./compute_op.h" +#include "../arithmetic/compute_expr.h" + +namespace tvm { +using namespace ir; +// TensorOpNode +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const TensorOpNode *op, IRPrinter *p) { + p->stream << "tensor_op(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(TensorOpNode); + +int TensorOpNode::num_outputs() const { + return static_cast(this->intrin->buffers.size() - this->inputs.size()); +} + +Array TensorOpNode::root_iter_vars() const { + Array ret = axis; + for (IterVar iv : tensor_axis) { + ret.push_back(iv); + } + for (IterVar iv : reduce_axis) { + ret.push_back(iv); + } + return ret; +} + +Type TensorOpNode::output_dtype(size_t i) const { + return this->intrin->buffers[this->inputs.size() + i]->dtype; +} + +Array TensorOpNode::output_shape(size_t i) const { + Array shape; + for (const auto& ivar : this->axis) { + shape.push_back(ivar->dom->extent); + } + size_t index = this->inputs.size() + i; + for (const auto& dim : this->intrin->buffers[index]->shape) { + shape.push_back(dim); + } + return shape; +} + + +Operation TensorOpNode::make(std::string name, + std::string tag, + Array axis, + Array reduce_axis, + Array inputs, + Array input_regions, + TensorIntrin intrin) { + auto n = std::make_shared(); + n->name = name; + n->tag = tag; + n->axis = axis; + + Array tout_shape = intrin->buffers[inputs.size()]->shape; + for (size_t i = 0; i < tout_shape.size(); ++i) { + Var var("ax" + std::to_string(i)); + Range range = Range::make_by_min_extent(make_zero(Int(32)), tout_shape[i]); + n->tensor_axis.push_back(IterVarNode::make(range, var, kOpaque)); + } + + n->reduce_axis = reduce_axis; + n->inputs = inputs; + n->input_regions = input_regions; + n->intrin = intrin; + return Operation(n); +} + +Array TensorOpNode::InputTensors() const { + return inputs; +} + +Operation TensorOpNode::ReplaceInputs( + const Operation& self, + const std::unordered_map& rmap) const { + CHECK_EQ(self.operator->(), this); + auto n = std::make_shared(*this); + auto intrin = std::make_shared(*(this->intrin.operator->())); + intrin->body = op::ReplaceTensor(this->intrin->body, rmap); + if (intrin->reduce_init.defined()) { + intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap); + } + if (intrin->reduce_update.defined()) { + intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap); + } + for (size_t i = 0; i < n->inputs.size(); ++i) { + Tensor t = n->inputs[i]; + if (rmap.count(t)) { + n->inputs.Set(i, rmap.at(t)); + } + } + + if (intrin->body.same_as(n->intrin->body) && + intrin->reduce_init.same_as(n->intrin->reduce_init) && + intrin->reduce_update.same_as(n->intrin->reduce_update) && + inputs.same_as(n->inputs)) { + return self; + } else { + n->intrin = TensorIntrin(intrin); + return Operation(n); + } +} + +void TensorOpNode::PropBoundToInputs( + const Operation& self, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { + for (size_t i = 0; i < this->inputs.size(); 
++i) {
+    Tensor t = this->inputs[i];
+    Region region = input_regions[i];
+
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t j = 0; j < t.ndim(); ++j) {
+      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+    }
+  }
+}
+
+void TensorOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  const TensorDom& tdom = tensor_dom.at(self.output(0));
+  for (size_t i = 0; i < this->axis.size(); ++i) {
+    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    CHECK(!out_dom_map->count(this->axis[i]));
+    (*out_dom_map)[this->axis[i]] = r;
+  }
+  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+    CHECK(!out_dom_map->count(this->reduce_axis[i]));
+    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+  }
+}
+
+Stmt TensorOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  HalideIR::Internal::Region bounds;
+  for (IterVar iv : this->axis) {
+    bounds.push_back(realize_map.at(iv));
+  }
+  size_t out_buff_idx = this->intrin->buffers.size();
+  for (const Expr extent : this->intrin->buffers[out_buff_idx - 1]->shape) {
+    bounds.push_back(Range(0, extent));
+  }
+  Stmt realize = body;
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+        t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (size_t i = 0; i < this->axis.size(); ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<Expr> tuple = {static_cast<int>(i),
+                               attr->dim_align_factor,
+                               attr->dim_align_offset};
+          realize = ir::AttrStmt::make(
+              t, ir::attr::buffer_dim_align,
+              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              realize);
+        }
+      }
+    }
+  }
+  return realize;
+}
+
+ComputeLoopNest MakeLoopNest(
+    const TensorOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+      debug_keep_trivial_loop);
+  ret.main_predicates = schedule::MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
+  if (stage->store_predicate.defined()) {
+    ret.main_predicates.push_back(stage->store_predicate);
+  }
+  if (self->reduce_axis.size() != 0) {
+    // try to find the location to insert the initialization.
+    // Fuse the initialization and provide loop when possible.
+    std::unordered_map<IterVar, int> update_state;
+    for (IterVar iv : self->reduce_axis) {
+      update_state[iv] = 2;
+    }
+    for (IterVar iv : self->axis) {
+      update_state[iv] = 1;
+    }
+    // find which iter var is related to reduction and which is related to axis.
+    schedule::PassDownBitMaskOr(stage, &update_state);
+    auto leaf_iter_vars = stage->leaf_iter_vars;
+    // find the first loop that is related to reduction.
+    size_t begin_loop = leaf_iter_vars.size();
+    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+      auto iv = leaf_iter_vars[i];
+      int flag = update_state.at(iv);
+      if ((flag & 2) != 0) {
+        begin_loop = i; break;
+      }
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
+    }
+    ret.num_common_loop = begin_loop;
+    // skip loops that do not relate to axis.
+    std::unordered_set<IterVar> skip_iter;
+    for (auto kv : update_state) {
+      int flag = kv.second;
+      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    }
+    ret.init_nest = op::MakeLoopNest(
+        stage, dom_map, begin_loop, true,
+        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+    ret.init_predicates = schedule::MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
+  }
+  // copy elision here.
+  return ret;
+}
+
+
+Stmt MakeTensorOp(const TensorOpNode* self,
+                  const Stage& stage,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  bool debug_keep_trivial_loop) {
+  std::unordered_map<IterVar, Range> out_dom;
+  std::unordered_map<Tensor, Array<Range> > in_region;
+
+  // Start bind data.
+  Stmt nop = Evaluate::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = self->InputTensors();
+
+  // input binding
+  size_t num_inputs = inputs.size();
+  for (size_t i = 0; i < num_inputs; ++i) {
+    Tensor tensor = inputs[i];
+    Region region = self->input_regions[i];
+    Buffer buffer = self->intrin->buffers[i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < region.size(); ++i) {
+      tuple.push_back(region[i]->min);
+      tuple.push_back(region[i]->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // output binding
+  for (int i = 0; i < self->num_outputs(); ++i) {
+    Tensor tensor = stage->op.output(i);
+    Buffer buffer = self->intrin->buffers[num_inputs + i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (const IterVar ivar : self->axis) {
+      tuple.push_back(ivar->var);
+      tuple.push_back(1);
+    }
+    for (const Expr extent : buffer->shape) {
+      tuple.push_back(0);
+      tuple.push_back(extent);
+    }
+
+    output_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // Check variable remap
+  std::unordered_map<const Variable*, Expr> vmap;
+  ir::ArgBinder binder(&vmap);
+
+  size_t tloc = stage->leaf_iter_vars.size();
+  ComputeLoopNest n = MakeLoopNest(self, stage, dom_map, debug_keep_trivial_loop);
+
+  if (self->reduce_axis.size() == 0) {
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(op::MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(self->intrin->body.defined())
+        << "Normal store op for intrin " << self << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, self->intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = op::Substitute(body, n.main_vmap);
+    Stmt ret = MergeNest(nest, body);
+    return ret;
+  } else {
+    // Need to split reduction
+    CHECK(self->intrin->reduce_update.defined())
+        << "Reduction update op is not defined";
+    // Need init and update steps
+    CHECK_NE(self->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(),
n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest( + n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); + + if (self->intrin->reduce_init.defined()) { + // init nest + std::vector > init_nest( + n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); + Stmt init = MergeNest(output_bind_nest, self->intrin->reduce_init); + init = op::Substitute(init, n.init_vmap); + init = MergeNest(init_nest, init); + // The update + Stmt update = MergeNest(output_bind_nest, self->intrin->reduce_update); + update = MergeNest(input_bind_nest, update); + update = ir::Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = op::Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, Block::make(init, update)); + } else { + // When init op is not available, use body op for reset in the first iter. + CHECK(self->intrin->body.defined()) + << "Normal body op is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, + self->intrin->body, + self->intrin->reduce_update); + update = MergeNest(output_bind_nest, update); + update = MergeNest(input_bind_nest, update); + update = ir::Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = op::Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, update); + } + } +} + + +Stmt TensorOpNode::BuildProvide( + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { + CHECK_EQ(stage->op.operator->(), this); + Stmt ret = MakeTensorOp(this, stage, dom_map, debug_keep_trivial_loop); + return ret; +} + +} // namespace tvm diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 6423c4e942e4..7f3634e3635c 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -10,7 +10,6 @@ #include "op_util.h" #include "compute_op.h" #include "../schedule/message_passing.h" -#include "../arithmetic/compute_expr.h" namespace tvm { @@ -323,50 +322,6 @@ void VerifyTensorizeBody( } } -/*! - * \brief Transform the update part when there is no init func in tensorizing - * \param stage The stage for tensorizing. - * \param dom_map The range of each iter var. - * \param n The loop nest structured used in compute. - * \param body The body func in tensorize intrin - * \param update The update func in tensorize intrin - * \return Transformed result. 
- */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update) { - Array conds; - std::unordered_set banned; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { - IterVar iv = stage->leaf_iter_vars[i]; - auto iit = stage->iter_var_attrs.find(iv); - if (iit != stage->iter_var_attrs.end()) { - const IterVarAttr& attr = (*iit).second; - if (attr->iter_type == kTensorized) { - break; - } - } - if (iv->iter_type == kCommReduce) { - auto vit = dom_map.find(iv); - CHECK(vit != dom_map.end()); - const Range& vrange = vit->second; - conds.push_back(likely(iv->var > vrange->min)); - banned.insert(iv->var.get()); - } - } - for (const Expr& pred : n.main_predicates) { - if (ir::ExprUseVar(pred, banned)) { - LOG(FATAL) << "Tensorize update transform failed, the condition " - << pred << " has a conflict with the reset condition"; - } - } - - return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), - update, body); -} - Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 8591c77bd7cc..b323ab48b505 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -135,6 +135,151 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } + +// for tensor_op +Array CacheWriteWithReLayoutForTensorOp(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const TensorOpNode* tensor_op = orig_stage->op.as(); + std::unordered_set red_axis; + for (IterVar iv : tensor_op->reduce_axis) { + red_axis.insert(iv); + } + std::unordered_map dom_map; + Array new_axis; + + for (IterVar iv : tensor_op->root_iter_vars()) { + dom_map[iv] = iv->dom; + } + schedule::PassDownDomain(orig_stage, &dom_map, true); + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + { + // The source->cache + std::unordered_map value_map; + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + CHECK_EQ(iv->iter_type, kDataPar) + << "Can only relayout with in data parallel dimensions"; + Range dom = dom_map.at(iv); + IterVar new_iv = IterVarNode::make( + dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + new_axis.push_back(new_iv); + if (is_one(dom->min)) { + value_map[iv] = dom->extent; + } else { + value_map[iv] = iv->var; + vsub2newvar[iv->var.get()] = new_iv->var; + } + } + // skip reduction iteration. + std::unordered_set skip_bound_check; + for (IterVar iv : tensor_op->reduce_axis) { + skip_bound_check.insert(iv); + } + schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); + predicates = schedule::MakeBoundCheck( + orig_stage, dom_map, value_map, true, skip_bound_check); + // The root axis + for (IterVar iv : tensor_op->axis) { + vsub[iv->var.get()] = value_map.at(iv); + } + } + + Array new_regions; + for (Region old_region : tensor_op->input_regions) { + Region region; + for (Range r : old_region) { + Expr min = VarReplacer(vsub2newvar).Mutate(r->min); + Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + region.push_back(Range::make_by_min_extent(min, extent)); + } + new_regions.push_back(region); + } + + Operation cache_op = TensorOpNode::make( + tensor_op->name + "." 
+ scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->inputs, new_regions, tensor_op->intrin); + + // axis will be used in generating compute op + Array compute_axis = tensor_op->axis; + for (IterVar iv : tensor_op->tensor_axis) { + // new tensor axis with kDataPar IterVar type + IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + compute_axis.push_back(aiv); + } + + // The reader args + Array args; + { + // cache->compute + std::unordered_map value_map; + for (IterVar iv : compute_axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } + // tensorized region axis + for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { + IterVar iv = compute_axis[tensor_op->axis.size() + i]; + args.push_back(value_map.at(iv)); + } + } + + Array cache_tensor_list; + Array cache_expr_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_tensor_list.push_back(cache_tensor); + cache_expr_list.push_back(cache_tensor(args)); + } + + Operation orig_new_op = ComputeOpNode::make( + tensor_op->name, tensor_op->tag, {}, + compute_axis, cache_expr_list); + + + // The replace of the dataflow + std::unordered_map vmap; + std::unordered_map rvmap; + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + for (size_t i = 0; i < tensor_size; i++) { + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + } + ReplaceDataFlow(sch->stages, &vmap, &rvmap); + // mutate orig stage + orig_stage->op = orig_new_op; + orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); + orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + orig_stage->relations = Array(); + // create schedule for new cached stage. 
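+  // At this point orig_stage->op has been replaced by a plain ComputeOp that
+  // simply reads the cache tensor, while the TensorOp itself lives on as
+  // cache_op; the code below splices the new cache stage into the stage list
+  // right before the original stage so dataflow order is preserved.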
+ ArrayNode* stages = sch->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, orig_stage); + Stage cache_stage = Stage(cache_op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos, + cache_stage.node_); + sch->stage_map.Set(cache_op, cache_stage); + // Update group + cache_stage->group = orig_stage->group; + if (cache_stage->group.defined()) { + ++cache_stage->group->num_child_stages; + } + return cache_tensor_list; +} + + // Cache write and relayout the data according to loop pattern Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_array, @@ -143,6 +288,9 @@ Array CacheWriteWithReLayout(Schedule sch, sch->InvalidateCache(); Tensor tensor = tensor_array[0]; Stage orig_stage = sch[tensor->op]; + if (!strcmp(orig_stage->op->type_key(), "TensorOp")) { + return CacheWriteWithReLayoutForTensorOp(sch, tensor_array, scope); + } const ComputeOpNode* compute = orig_stage->op.as(); std::unordered_set red_axis; for (IterVar iv : compute->reduce_axis) { @@ -198,7 +346,7 @@ Array CacheWriteWithReLayout(Schedule sch, body = InjectPredicate(predicates, body); body = VarReplacer(vsub2newvar).Mutate(body); // Reduce nodes in ONE computeOp must be the same except value_index - // This is right only if the oringinal body ensures Reduce nodes are the same + // This is right only if the original body ensures Reduce nodes are the same if (body->is_type()) { const ir::Reduce* reduce_body = body.as(); if (first_reduce != nullptr) { @@ -276,6 +424,7 @@ Array CacheWriteWithReLayout(Schedule sch, return cache_tensor_list; } + Array Schedule::cache_write(const Array& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); @@ -299,11 +448,11 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { (*this)->InvalidateCache(); Stage orig_stage = operator[](tensor->op); - const ComputeOpNode* compute = tensor->op.as(); - CHECK(compute) - << "cache write only take ComputeOp as writers"; - CHECK_EQ(compute->num_outputs(), 1) - << "cache write only support single output ComputeOp"; + // const ComputeOpNode* compute = tensor->op.as(); + // CHECK(compute) + // << "cache write only take ComputeOp as writers"; + // CHECK_EQ(compute->num_outputs(), 1) + // << "cache write only support single output ComputeOp"; return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; } diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 1d8603dfc98b..454ca624759e 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -85,6 +85,37 @@ def test_tensor_reduce(): assert(str(C_loaded) == str(C)) +def test_tensor_tensor_op(): + M = 1024 + factor = 16 + + def intrin_vadd(n): + x = tvm.placeholder((n,)) + y = tvm.placeholder((n,)) + z = tvm.compute(x.shape, lambda i: x[i] + y[i]) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) + return ib.get() + + with tvm.build_config(offset_factor=n): + return tvm.decl_tensor_intrin(z.op, intrin_func) + + A = tvm.placeholder((M/factor, factor), name="A") + B = tvm.placeholder((M/factor, factor), name="B") + + intrin = intrin_vadd(factor) + C = tvm.tensor_op([M/factor,], [factor,], + lambda i: [A[i, 0:factor], B[i, 0:factor]], + intrin, name='C') + + s = tvm.create_schedule(C.op) + stmt = tvm.lower(s, [A, B, C], simple_mode=True) + intrin_code = 
stmt.body.body.value + assert isinstance(intrin_code, tvm.expr.Call) + + def test_tensor_scan(): m = tvm.var("m") n = tvm.var("n") @@ -193,6 +224,7 @@ def test_tensor_inputs(): test_tensor_slice() test_tensor() test_tensor_reduce() + test_tensor_tensor_op() test_tensor_scan() test_scan_multi_out() test_extern() From 97b3f653b760c7866d8b4e0df184312ac1ebeda4 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 16 Sep 2018 08:07:55 -0700 Subject: [PATCH 2/7] [TensorOp] Support for intrin(..) and rename to TensorComputeOp. --- include/tvm/operation.h | 19 ++- include/tvm/tensor_intrin.h | 55 +++++++ python/tvm/api.py | 188 ++++++++++++---------- python/tvm/tensor.py | 7 +- python/tvm/tensor_intrin.py | 7 +- src/api/api_lang.cc | 22 ++- src/lang/tensor.cc | 51 +++++- src/op/tensor_op.cc | 91 ++++++----- src/schedule/schedule_dataflow_rewrite.cc | 13 +- tests/python/unittest/test_lang_tensor.py | 49 +++--- 10 files changed, 319 insertions(+), 183 deletions(-) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 5eeac93dfc94..f87083303903 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode { TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); }; -class TensorOpNode : public OperationNode { +class TensorComputeOpNode : public OperationNode { public: Array axis; @@ -197,7 +197,7 @@ class TensorOpNode : public OperationNode { TensorIntrin intrin; /*! \brief constructor */ - TensorOpNode() {} + TensorComputeOpNode() {} // override functions int num_outputs() const final; @@ -237,13 +237,20 @@ class TensorOpNode : public OperationNode { static Operation make(std::string name, std::string tag, Array axis, + Array tensor_axis, + TensorIntrinCall intrin_call); + + static Operation make(std::string name, + std::string tag, + Array axis, + Array tensor_axis, Array reduce_axis, - Array inputs, - Array input_regions, + Array tensors, + Array regions, TensorIntrin intrin); - static constexpr const char* _type_key = "TensorOp"; - TVM_DECLARE_NODE_TYPE_INFO(TensorOpNode, OperationNode); + static constexpr const char* _type_key = "TensorComputeOp"; + TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode); }; /*! diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 944498d1e615..1cbd52dc4d0b 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -26,6 +26,14 @@ class TensorIntrin : public NodeRef { */ inline const TensorIntrinNode* operator->() const; + // template + // inline Stmt operator()(Args&& ...args) const { + // Array inputs{std::forward(args)...}; + // return operator()(inputs); + // } + + // TVM_DLL TensorIntrinCall operator()(Array inputs) const; + /*! \brief specify container node */ using ContainerType = TensorIntrinNode; }; @@ -89,5 +97,52 @@ class TensorIntrinNode : public Node { inline const TensorIntrinNode* TensorIntrin::operator->() const { return static_cast(node_.get()); } + + +// Internal node container of tensor intrinsics. +class TensorIntrinCallNode; + +/*! \brief Tensor intrinsic node. */ +class TensorIntrinCall : public NodeRef { + public: + TensorIntrinCall() {} + explicit TensorIntrinCall(std::shared_ptr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const TensorIntrinCallNode* operator->() const; + + /*! 
\brief specify container node */ + using ContainerType = TensorIntrinCallNode; +}; + +class TensorIntrinCallNode : public Node { + public: + TensorIntrin intrin; + Array tensors; + Array regions; + Array reduce_axis; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("intrin", &intrin); + v->Visit("tensors", &tensors); + v->Visit("regions", ®ions); + v->Visit("reduce_axis", &reduce_axis); + } + + static TensorIntrinCall make(TensorIntrin intrin, + Array tensors, + Array regions, + Array reduce_axis); + + static constexpr const char* _type_key = "TensorIntrinCall"; + TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); +}; + +inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { + return static_cast(node_.get()); +} + } // namespace tvm #endif // TVM_TENSOR_INTRIN_H_ diff --git a/python/tvm/api.py b/python/tvm/api.py index 1c1dbfaa1532..4d6740403224 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -248,21 +248,38 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): ndim = len(shape) code = fcompute.__code__ - if fcompute.__code__.co_argcount == 0: + out_ndim = ndim + if code.co_argcount == 0: arg_names = ["i%d" % i for i in range(ndim)] else: arg_names = code.co_varnames[:code.co_argcount] + out_ndim = code.co_argcount - if ndim != len(arg_names): + # TODO check ndim, arg_names + if out_ndim != len(arg_names): raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) - dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)] + dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])] body = fcompute(*[v.var for v in dim_var]) - if not isinstance(body, (list, tuple)): - body = [body] - body = convert(body) - op_node = _api_internal._ComputeOp( - name, tag, attrs, dim_var, body) + + if isinstance(body, _tensor.TensorIntrinCall): + tensor_var = [] + for i, s in enumerate(shape[out_ndim:]): + name = "ax" + str(i) + tensor_var.append(_IterVar((0, s), name, 4)) + op_node = _api_internal._TensorComputeOp(name, + tag, + dim_var, + tensor_var, + body) + else: + if not isinstance(body, (list, tuple)): + body = [body] + body = convert(body) + # print('body: {0}'.format(body)) + op_node = _api_internal._ComputeOp( + name, tag, attrs, dim_var, body) + num = op_node.num_outputs outputs = tuple(op_node.output(i) for i in range(num)) return outputs[0] if num == 1 else outputs @@ -337,84 +354,85 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr return res[0] if len(res) == 1 else res -def tensor_op(out_dims, - in_dims, # pylint: disable=unused-argument - finputs, - intrin, - raxis=None, - name='tensor_op', - tag=""): - """Construct new tensors with intrinsic. - - Parameters - ---------- - out_dims: tuple - The dimensions out of the tensorized region, which can be - scheduled through `reorder`, `split`. - - in_dims: tuple - The dimensions inside of the tensorized region, which cannot - be manipulated. - - finputs: lambda function of out_dims -> list of TensorSlice - Specifies involved regions of input tensors. - - tensor_intrin : TensorIntrin - The tensor intrinsic used for computation. - - raxis : IterVar - An iteration variable representing the value. - - name: str, optional - The name hint of the tensor - - tag: str, optional - Additonal tag information about the compute. 
- """ - def _get_region(tslice): - region = [] - for idx in tslice.indices: - if isinstance(idx, slice): - assert idx.step is None - region.append(Range(idx.start, idx.stop)) +def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var else: - if isinstance(idx, _schedule.IterVar): - begin = idx.var - else: - begin = idx - region.append(_make.range_by_min_extent(begin, 1)) - return region - - if _tag.TagScope.current is not None: - if tag != "": - raise ValueError("nested tag is not allowed for now") - tag = _tag.TagScope.current.tag - - code = finputs.__code__ - if finputs.__code__.co_argcount == 0: - arg_names = ["i%d" % i for i in range(ndim)] - else: - arg_names = code.co_varnames[:code.co_argcount] - - if len(out_dims) != len(arg_names): - raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) - - out_var = [_IterVar((0, extent), arg_name, 0) for arg_name, extent in zip(arg_names, out_dims)] - if isinstance(raxis, _schedule.IterVar): - raxis = [raxis] - if raxis is None: - raxis = [] - tensor_regions = finputs(*[v.var for v in out_var]) - - op = _api_internal._TensorOp(name, - tag, - out_var, - raxis, - [x.tensor for x in tensor_regions], - [_get_region(x) for x in tensor_regions], - intrin) - # only support single output - return op.output(0) + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + + +# def tensor_op(out_dims, +# in_dims, # pylint: disable=unused-argument +# finputs, +# intrin, +# raxis=None, +# name='tensor_op', +# tag=""): +# """Construct new tensors with intrinsic. +# +# Parameters +# ---------- +# out_dims: tuple +# The dimensions out of the tensorized region, which can be +# scheduled through `reorder`, `split`. +# +# in_dims: tuple +# The dimensions inside of the tensorized region, which cannot +# be manipulated. +# +# finputs: lambda function of out_dims -> list of TensorSlice +# Specifies involved regions of input tensors. +# +# tensor_intrin : TensorIntrin +# The tensor intrinsic used for computation. +# +# raxis : IterVar +# An iteration variable representing the value. +# +# name: str, optional +# The name hint of the tensor +# +# tag: str, optional +# Additonal tag information about the compute. 
+# """ +# if _tag.TagScope.current is not None: +# if tag != "": +# raise ValueError("nested tag is not allowed for now") +# tag = _tag.TagScope.current.tag +# +# code = finputs.__code__ +# if finputs.__code__.co_argcount == 0: +# arg_names = ["i%d" % i for i in range(ndim)] +# else: +# arg_names = code.co_varnames[:code.co_argcount] +# +# if len(out_dims) != len(arg_names): +# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) +# +# out_var = [_IterVar((0, extent), arg_name, 0) for arg_name, extent in zip(arg_names, out_dims)] +# if isinstance(raxis, _schedule.IterVar): +# raxis = [raxis] +# if raxis is None: +# raxis = [] +# tensor_regions = finputs(*[v.var for v in out_var]) +# +# op = _api_internal._TensorOp(name, +# tag, +# out_var, +# raxis, +# [x.tensor for x in tensor_regions], +# [_get_region(x) for x in tensor_regions], +# intrin) +# # only support single output +# return op.output(0) def extern(shape, diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index 254e17710d47..b20a68ac0800 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -30,6 +30,11 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype +@register_node +class TensorIntrinCall(NodeBase): + """Intermediate structure for calling a tensor intrinsic.""" + pass + itervar_cls = None @@ -172,6 +177,6 @@ class ExternOp(Operation): @register_node -class TensorOp(Operation): +class TensorComputeOp(Operation): """Tensor operation.""" pass diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 62f8c8897d10..a9b45c772a7e 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -17,8 +17,11 @@ class TensorIntrin(NodeBase): -------- decl_tensor_intrin: Construct a TensorIntrin """ - pass - + def __call__(self, *args): + tensors = [x.tensor for x in args] + regions = [_api._get_region(x) for x in args] + # TODO + return _api_internal._TensorIntrinCall(self, tensors, regions, []) def decl_tensor_intrin(op, fcompute, diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8cc34dd74448..67377e468a4f 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin") args[6]); }); +TVM_REGISTER_API("_TensorIntrinCall") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TensorIntrinCallNode::make(args[0], + args[1], + args[2], + args[3]); + }); + TVM_REGISTER_API("_TensorEqual") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Tensor() == args[1].operator Tensor(); @@ -278,15 +286,13 @@ TVM_REGISTER_API("_ScanOp") args[7]); }); -TVM_REGISTER_API("_TensorOp") +TVM_REGISTER_API("_TensorComputeOp") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TensorOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6]); + *ret = TensorComputeOpNode::make(args[0], + args[1], + args[2], + args[3], + args[4]); }); TVM_REGISTER_API("_ExternOp") diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 4f9c3e9d1782..57e431a4e123 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -10,6 +10,8 @@ namespace tvm { +// Tensor + Expr Tensor::operator()(Array indices) const { Array arr(indices.begin(), indices.end()); return operator()(arr); @@ -26,6 +28,21 @@ Expr Tensor::operator()(Array indices) const { return n; } +Tensor Operation::output(size_t i) const { + auto node = std::make_shared(); + node->op = *this; + node->value_index = i; + node->dtype = (*this)->output_dtype(i); + node->shape = (*this)->output_shape(i); + return 
Tensor(node); +} + +// TensorIntrinCall TensorIntrin::operator()(Array inputs) const { +// using HalideIR::Internal::Call; +// LOG(FATAL) << "CallTensorIntrin"; +// CHECK_EQ(tensors.size(), regions.size()); +// } + Tensor TensorNode::make(Array shape, Type dtype, Operation op, @@ -46,14 +63,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorNode); -Tensor Operation::output(size_t i) const { - auto node = make_node(); - node->op = *this; - node->value_index = i; - node->dtype = (*this)->output_dtype(i); - node->shape = (*this)->output_shape(i); - return Tensor(node); -} + +// TensorIntrin TensorIntrin TensorIntrinNode::make(std::string name, Operation op, @@ -79,4 +90,28 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + + +// TensorIntrinCall + +TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, + Array tensors, + Array regions, + Array reduce_axis) { + auto n = std::make_shared(); + LOG(INFO) << "TensorIntrinCallNode make"; + n->intrin = std::move(intrin); + n->tensors = std::move(tensors); + n->regions = std::move(regions); + n->reduce_axis = std::move(reduce_axis); + return TensorIntrinCall(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const TensorIntrinCallNode *n, IRPrinter *p) { + p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; + }); + +TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); + } // namespace tvm diff --git a/src/op/tensor_op.cc b/src/op/tensor_op.cc index 13ca38b63529..e609c3c2393e 100644 --- a/src/op/tensor_op.cc +++ b/src/op/tensor_op.cc @@ -15,19 +15,20 @@ namespace tvm { using namespace ir; -// TensorOpNode +// TensorComputeOpNode TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const TensorOpNode *op, IRPrinter *p) { - p->stream << "tensor_op(" << op->name << ", " << op << ")"; +.set_dispatch([](const TensorComputeOpNode *op, + IRPrinter *p) { + p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; }); -TVM_REGISTER_NODE_TYPE(TensorOpNode); +TVM_REGISTER_NODE_TYPE(TensorComputeOpNode); -int TensorOpNode::num_outputs() const { +int TensorComputeOpNode::num_outputs() const { return static_cast(this->intrin->buffers.size() - this->inputs.size()); } -Array TensorOpNode::root_iter_vars() const { +Array TensorComputeOpNode::root_iter_vars() const { Array ret = axis; for (IterVar iv : tensor_axis) { ret.push_back(iv); @@ -38,11 +39,11 @@ Array TensorOpNode::root_iter_vars() const { return ret; } -Type TensorOpNode::output_dtype(size_t i) const { +Type TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Array TensorOpNode::output_shape(size_t i) const { +Array TensorComputeOpNode::output_shape(size_t i) const { Array shape; for (const auto& ivar : this->axis) { shape.push_back(ivar->dom->extent); @@ -55,41 +56,50 @@ Array TensorOpNode::output_shape(size_t i) const { } -Operation TensorOpNode::make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - Array inputs, - Array input_regions, - TensorIntrin intrin) { - auto n = std::make_shared(); +Operation TensorComputeOpNode::make(std::string name, + std::string tag, + Array axis, + Array tensor_axis, + TensorIntrinCall intrin_call) { + return TensorComputeOpNode::make(name, + tag, + axis, + tensor_axis, + intrin_call->reduce_axis, + intrin_call->tensors, + intrin_call->regions, + intrin_call->intrin); +} + +Operation TensorComputeOpNode::make(std::string name, + std::string tag, + Array axis, + 
Array tensor_axis, + Array reduce_axis, + Array tensors, + Array regions, + TensorIntrin intrin) { + auto n = std::make_shared(); n->name = name; n->tag = tag; n->axis = axis; - - Array tout_shape = intrin->buffers[inputs.size()]->shape; - for (size_t i = 0; i < tout_shape.size(); ++i) { - Var var("ax" + std::to_string(i)); - Range range = Range::make_by_min_extent(make_zero(Int(32)), tout_shape[i]); - n->tensor_axis.push_back(IterVarNode::make(range, var, kOpaque)); - } - + n->tensor_axis = tensor_axis; n->reduce_axis = reduce_axis; - n->inputs = inputs; - n->input_regions = input_regions; + n->inputs = tensors; + n->input_regions = regions; n->intrin = intrin; return Operation(n); } -Array TensorOpNode::InputTensors() const { +Array TensorComputeOpNode::InputTensors() const { return inputs; } -Operation TensorOpNode::ReplaceInputs( +Operation TensorComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = std::make_shared(*this); + auto n = std::make_shared(*this); auto intrin = std::make_shared(*(this->intrin.operator->())); intrin->body = op::ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { @@ -116,7 +126,7 @@ Operation TensorOpNode::ReplaceInputs( } } -void TensorOpNode::PropBoundToInputs( +void TensorComputeOpNode::PropBoundToInputs( const Operation& self, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { @@ -133,7 +143,7 @@ void TensorOpNode::PropBoundToInputs( } } -void TensorOpNode::GatherBound( +void TensorComputeOpNode::GatherBound( const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const { @@ -149,7 +159,7 @@ void TensorOpNode::GatherBound( } } -Stmt TensorOpNode::BuildRealize( +Stmt TensorComputeOpNode::BuildRealize( const Stage& stage, const std::unordered_map& realize_map, const Stmt& body) const { @@ -188,7 +198,7 @@ Stmt TensorOpNode::BuildRealize( } ComputeLoopNest MakeLoopNest( - const TensorOpNode* self, + const TensorComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { @@ -254,13 +264,10 @@ ComputeLoopNest MakeLoopNest( } -Stmt MakeTensorOp(const TensorOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { - std::unordered_map out_dom; - std::unordered_map > in_region; - +Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { // Start bind data. 
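 // Binding attaches each input/output region to the buffer the intrinsic
 // declared for it, via an AttrStmt tagged ir::attr::buffer_bind_scope; the
 // tvm_tuple call inside packs the (min, extent) pairs of the accessed region.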
Stmt nop = Evaluate::make(0); std::vector input_bind_nest, output_bind_nest; @@ -373,12 +380,12 @@ Stmt MakeTensorOp(const TensorOpNode* self, } -Stmt TensorOpNode::BuildProvide( +Stmt TensorComputeOpNode::BuildProvide( const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = MakeTensorOp(this, stage, dom_map, debug_keep_trivial_loop); + Stmt ret = MakeTensorComputeOp(this, stage, dom_map, debug_keep_trivial_loop); return ret; } diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index b323ab48b505..02704fb56124 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -137,14 +137,14 @@ Tensor Schedule::cache_read(const Tensor& tensor, // for tensor_op -Array CacheWriteWithReLayoutForTensorOp(Schedule sch, +Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); Tensor tensor = tensor_array[0]; Stage orig_stage = sch[tensor->op]; - const TensorOpNode* tensor_op = orig_stage->op.as(); + const TensorComputeOpNode* tensor_op = orig_stage->op.as(); std::unordered_set red_axis; for (IterVar iv : tensor_op->reduce_axis) { red_axis.insert(iv); @@ -202,9 +202,10 @@ Array CacheWriteWithReLayoutForTensorOp(Schedule sch, new_regions.push_back(region); } - Operation cache_op = TensorOpNode::make( + Operation cache_op = TensorComputeOpNode::make( tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->inputs, new_regions, tensor_op->intrin); + tensor_op->reduce_axis, tensor_op->tensor_axis, + tensor_op->inputs, new_regions, tensor_op->intrin); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; @@ -288,8 +289,8 @@ Array CacheWriteWithReLayout(Schedule sch, sch->InvalidateCache(); Tensor tensor = tensor_array[0]; Stage orig_stage = sch[tensor->op]; - if (!strcmp(orig_stage->op->type_key(), "TensorOp")) { - return CacheWriteWithReLayoutForTensorOp(sch, tensor_array, scope); + if (!strcmp(orig_stage->op->type_key(), "TensorComputeOp")) { + return CacheWriteWithReLayoutForTensorComputeOp(sch, tensor_array, scope); } const ComputeOpNode* compute = orig_stage->op.as(); std::unordered_set red_axis; diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 454ca624759e..d3bd1f16d9a3 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -84,10 +84,10 @@ def test_tensor_reduce(): assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(str(C_loaded) == str(C)) - -def test_tensor_tensor_op(): - M = 1024 +def test_tensor_region(): + m = 1024 factor = 16 + dtype = 'float32' def intrin_vadd(n): x = tvm.placeholder((n,)) @@ -102,18 +102,16 @@ def intrin_func(ins, outs): with tvm.build_config(offset_factor=n): return tvm.decl_tensor_intrin(z.op, intrin_func) - A = tvm.placeholder((M/factor, factor), name="A") - B = tvm.placeholder((M/factor, factor), name="B") + vadd = intrin_vadd(factor) - intrin = intrin_vadd(factor) - C = tvm.tensor_op([M/factor,], [factor,], - lambda i: [A[i, 0:factor], B[i, 0:factor]], - intrin, name='C') + A = tvm.placeholder((m/factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((m/factor, factor), name="B", dtype=dtype) + C = tvm.compute((m/factor, factor), + lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) s = 
tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - intrin_code = stmt.body.body.value - assert isinstance(intrin_code, tvm.expr.Call) + print(stmt) def test_tensor_scan(): @@ -217,17 +215,18 @@ def test_tensor_inputs(): if __name__ == "__main__": - test_rank_zero() - test_tensor_inputs() - test_tensor_reduce_multi_axis() - test_conv1d() - test_tensor_slice() - test_tensor() - test_tensor_reduce() - test_tensor_tensor_op() - test_tensor_scan() - test_scan_multi_out() - test_extern() - test_extern_multi_out() - test_tuple_inputs() - test_tuple_with_different_deps() + # test_rank_zero() + # test_tensor_inputs() + # test_tensor_reduce_multi_axis() + # test_conv1d() + # test_tensor_slice() + # test_tensor() + test_tensor_region() + # test_tensor_reduce() + # test_tensor_tensor_op() + # test_tensor_scan() + # test_scan_multi_out() + # test_extern() + # test_extern_multi_out() + # test_tuple_inputs() + # test_tuple_with_different_deps() From 753cb03fccaeab19459022cd9643ff7ad8c08e86 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 22 Sep 2018 09:02:53 -0700 Subject: [PATCH 3/7] [TensorOp] Add testcase for reduction. --- include/tvm/tensor_intrin.h | 2 +- python/tvm/api.py | 4 +- python/tvm/tensor_intrin.py | 11 +++- src/lang/tensor.cc | 5 +- src/op/tensor_op.cc | 6 +- src/pass/arg_binder.cc | 4 +- tests/python/unittest/test_lang_tensor.py | 77 ++++++++++++++++++----- 7 files changed, 79 insertions(+), 30 deletions(-) diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 1cbd52dc4d0b..6e2db89496d1 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -106,7 +106,7 @@ class TensorIntrinCallNode; class TensorIntrinCall : public NodeRef { public: TensorIntrinCall() {} - explicit TensorIntrinCall(std::shared_ptr n) : NodeRef(n) {} + explicit TensorIntrinCall(NodePtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/python/tvm/api.py b/python/tvm/api.py index 4d6740403224..8f5068e06cfa 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -15,7 +15,6 @@ from . import _api_internal from . import make as _make from . import expr as _expr -from . import stmt as _stmt from . import tensor as _tensor from . import schedule as _schedule from . 
import container as _container @@ -417,7 +416,8 @@ def _get_region(tslice): # if len(out_dims) != len(arg_names): # raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) # -# out_var = [_IterVar((0, extent), arg_name, 0) for arg_name, extent in zip(arg_names, out_dims)] +# out_var = [_IterVar((0, extent), arg_name, 0) +# for arg_name, extent in zip(arg_names, out_dims)] # if isinstance(raxis, _schedule.IterVar): # raxis = [raxis] # if raxis is None: diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index a9b45c772a7e..818c487d0f35 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -17,11 +17,16 @@ class TensorIntrin(NodeBase): -------- decl_tensor_intrin: Construct a TensorIntrin """ - def __call__(self, *args): + def __call__(self, *args, **kwargs): tensors = [x.tensor for x in args] regions = [_api._get_region(x) for x in args] - # TODO - return _api_internal._TensorIntrinCall(self, tensors, regions, []) + reduce_axis = [] + if "reduce_axis" in kwargs: + reduce_axis = kwargs["reduce_axis"] + if not isinstance(reduce_axis, (list, tuple)): + reduce_axis = [reduce_axis] + reduce_axis = _api.convert(reduce_axis) + return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis) def decl_tensor_intrin(op, fcompute, diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 57e431a4e123..9544de33a222 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -29,7 +29,7 @@ Expr Tensor::operator()(Array indices) const { } Tensor Operation::output(size_t i) const { - auto node = std::make_shared(); + auto node = make_node(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -98,8 +98,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis) { - auto n = std::make_shared(); - LOG(INFO) << "TensorIntrinCallNode make"; + auto n = make_node(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); diff --git a/src/op/tensor_op.cc b/src/op/tensor_op.cc index e609c3c2393e..fd9122c8b092 100644 --- a/src/op/tensor_op.cc +++ b/src/op/tensor_op.cc @@ -79,7 +79,7 @@ Operation TensorComputeOpNode::make(std::string name, Array tensors, Array regions, TensorIntrin intrin) { - auto n = std::make_shared(); + auto n = make_node(); n->name = name; n->tag = tag; n->axis = axis; @@ -99,8 +99,8 @@ Operation TensorComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = std::make_shared(*this); - auto intrin = std::make_shared(*(this->intrin.operator->())); + auto n = make_node(*this); + auto intrin = make_node(*(this->intrin.operator->())); intrin->body = op::ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 0fac313c079b..623886c31b86 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg, // bind pointer and offset. 
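// A buffer declared with a static elem_offset of zero cannot be bound to a
// view that carries a nonzero offset; the enriched message below reports both
// offsets so such mismatches are easier to debug.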
if (is_zero(arg->elem_offset)) { CHECK(is_zero(value->elem_offset)) - << "Trying to bind a Buffer with offset into one without offset"; + << "Trying to bind a Buffer with offset into one without offset; " + << "required elem_offset=" << arg->elem_offset + << ", provided elem_offset=" << value->elem_offset; } this->Bind(arg->data, value->data, arg_name + ".data"); diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index d3bd1f16d9a3..aeccf9cddab6 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -84,7 +84,7 @@ def test_tensor_reduce(): assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(str(C_loaded) == str(C)) -def test_tensor_region(): +def test_tensor_compute1(): m = 1024 factor = 16 dtype = 'float32' @@ -111,8 +111,51 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - print(stmt) + assert isinstance(stmt.body.body, tvm.stmt.Evaluate) +def test_tensor_compute2(): + M = 2048 + N = 1024 + L = 1024 + factor = 16 + factor1 = 32 + factor2 = 32 + dtype = 'float32' + + def intrin_gemm(m, n, l): + k = tvm.reduce_axis((0, l)) + x = tvm.placeholder((m, l)) + y = tvm.placeholder((n, l)) + # in theory, no relation + z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k)) + + def intrin_func(ins, outs): + x_ptr = ins[0].access_ptr("r") + y_ptr = ins[1].access_ptr("r") + z_ptr = outs[0].access_ptr("w") + body = tvm.call_packed( + "gemv", x_ptr, y_ptr, z_ptr, m, n, l) + reset = tvm.call_packed( + "fill_zero", z_ptr, m, n) + update = tvm.call_packed( + "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l) + return body, reset, update + + with tvm.build_config(offset_factor=n): + return tvm.decl_tensor_intrin(z.op, intrin_func) + + vgemm = intrin_gemm(factor1, factor2, factor) + + A = tvm.placeholder((M/factor1, L/factor, factor1, factor), name="A", dtype=dtype) + B = tvm.placeholder((N/factor2, L/factor, factor2, factor), name="B", dtype=dtype) + k = tvm.reduce_axis((0, L/factor), name='k') + C = tvm.compute((M/factor1, N/factor2, factor1, factor2), + lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) + + s = tvm.create_schedule(C.op) + stmt = tvm.lower(s, [A, B, C], simple_mode=True) + assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate) def test_tensor_scan(): m = tvm.var("m") @@ -215,18 +258,18 @@ def test_tensor_inputs(): if __name__ == "__main__": - # test_rank_zero() - # test_tensor_inputs() - # test_tensor_reduce_multi_axis() - # test_conv1d() - # test_tensor_slice() - # test_tensor() - test_tensor_region() - # test_tensor_reduce() - # test_tensor_tensor_op() - # test_tensor_scan() - # test_scan_multi_out() - # test_extern() - # test_extern_multi_out() - # test_tuple_inputs() - # test_tuple_with_different_deps() + test_rank_zero() + test_tensor_inputs() + test_tensor_reduce_multi_axis() + test_conv1d() + test_tensor_slice() + test_tensor() + test_tensor_compute1() + test_tensor_compute2() + test_tensor_reduce() + test_tensor_scan() + test_scan_multi_out() + test_extern() + test_extern_multi_out() + test_tuple_inputs() + test_tuple_with_different_deps() From cf0134df05624e65d15b51f4d9f81c8f12e8c180 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 23 Sep 2018 11:48:17 -0700 Subject: [PATCH 4/7] [TensorOp] Add testcase for scheduling tensor_compute_op. 
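This lets a TensorComputeOp stage be manipulated like a ComputeOp stage: split/reorder/tile on the outer axes, cache_read/cache_write, and compute_at. As a minimal sketch of the pattern the new tests exercise (assuming the `intrin_vadd` factory defined in test_schedule_schedule_ops.py below; any tensor intrinsic works the same way):

    import tvm

    M, factor = 1024, 16
    A = tvm.placeholder((M//factor, factor), name='A')
    B = tvm.placeholder((M//factor, factor), name='B')
    vadd = intrin_vadd(factor)
    C = tvm.compute((M//factor, factor),
                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
    s = tvm.create_schedule(C.op)        # C.op is a TensorComputeOp
    ai, ax = s[C].op.axis                # ax is owned by the intrinsic
    aio, aii = s[C].split(ai, factor=4)  # only ai is schedulable
    print(tvm.lower(s, [A, B, C], simple_mode=True))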
--- include/tvm/operation.h | 7 +- include/tvm/tensor_intrin.h | 8 - python/tvm/api.py | 91 +----- python/tvm/tensor.py | 12 +- python/tvm/tensor_intrin.py | 18 +- src/lang/tensor.cc | 6 - src/op/{tensor_op.cc => tensor_compute_op.cc} | 96 +++--- src/schedule/schedule_dataflow_rewrite.cc | 290 ++++++++---------- .../unittest/test_schedule_schedule_ops.py | 132 ++++++++ 9 files changed, 344 insertions(+), 316 deletions(-) rename src/op/{tensor_op.cc => tensor_compute_op.cc} (85%) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index f87083303903..48591d316422 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -186,6 +186,8 @@ class TensorComputeOpNode : public OperationNode { public: Array axis; + Array out_axis; + Array tensor_axis; Array reduce_axis; @@ -229,6 +231,7 @@ class TensorComputeOpNode : public OperationNode { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("axis", &axis); + v->Visit("out_axis", &out_axis); v->Visit("tensor_axis", &tensor_axis); v->Visit("reduce_axis", &reduce_axis); v->Visit("inputs", &inputs); @@ -236,13 +239,13 @@ class TensorComputeOpNode : public OperationNode { static Operation make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, TensorIntrinCall intrin_call); static Operation make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, Array reduce_axis, Array tensors, diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 6e2db89496d1..f70735b02264 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -26,14 +26,6 @@ class TensorIntrin : public NodeRef { */ inline const TensorIntrinNode* operator->() const; - // template - // inline Stmt operator()(Args&& ...args) const { - // Array inputs{std::forward(args)...}; - // return operator()(inputs); - // } - - // TVM_DLL TensorIntrinCall operator()(Array inputs) const; - /*! 
\brief specify container node */ using ContainerType = TensorIntrinNode; }; diff --git a/python/tvm/api.py b/python/tvm/api.py index 8f5068e06cfa..cc1792fa511d 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -3,7 +3,6 @@ from __future__ import absolute_import as _abs from numbers import Integral as _Integral -from collections import namedtuple from ._ffi.base import string_types from ._ffi.node import register_node, NodeBase @@ -244,6 +243,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.current.tag shape = (shape,) if isinstance(shape, _expr.Expr) else shape + # for python3 + shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) ndim = len(shape) code = fcompute.__code__ @@ -254,7 +255,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): arg_names = code.co_varnames[:code.co_argcount] out_ndim = code.co_argcount - # TODO check ndim, arg_names if out_ndim != len(arg_names): raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) @@ -264,8 +264,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): if isinstance(body, _tensor.TensorIntrinCall): tensor_var = [] for i, s in enumerate(shape[out_ndim:]): - name = "ax" + str(i) - tensor_var.append(_IterVar((0, s), name, 4)) + var_name = "ax" + str(i) + tensor_var.append(_IterVar((0, s), var_name, 4)) op_node = _api_internal._TensorComputeOp(name, tag, dim_var, @@ -275,7 +275,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): if not isinstance(body, (list, tuple)): body = [body] body = convert(body) - # print('body: {0}'.format(body)) op_node = _api_internal._ComputeOp( name, tag, attrs, dim_var, body) @@ -353,88 +352,6 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr return res[0] if len(res) == 1 else res -def _get_region(tslice): - region = [] - for idx in tslice.indices: - if isinstance(idx, slice): - assert idx.step is None - region.append(Range(idx.start, idx.stop)) - else: - if isinstance(idx, _schedule.IterVar): - begin = idx.var - else: - begin = idx - region.append(_make.range_by_min_extent(begin, 1)) - return region - - -# def tensor_op(out_dims, -# in_dims, # pylint: disable=unused-argument -# finputs, -# intrin, -# raxis=None, -# name='tensor_op', -# tag=""): -# """Construct new tensors with intrinsic. -# -# Parameters -# ---------- -# out_dims: tuple -# The dimensions out of the tensorized region, which can be -# scheduled through `reorder`, `split`. -# -# in_dims: tuple -# The dimensions inside of the tensorized region, which cannot -# be manipulated. -# -# finputs: lambda function of out_dims -> list of TensorSlice -# Specifies involved regions of input tensors. -# -# tensor_intrin : TensorIntrin -# The tensor intrinsic used for computation. -# -# raxis : IterVar -# An iteration variable representing the value. -# -# name: str, optional -# The name hint of the tensor -# -# tag: str, optional -# Additonal tag information about the compute. 
-# """ -# if _tag.TagScope.current is not None: -# if tag != "": -# raise ValueError("nested tag is not allowed for now") -# tag = _tag.TagScope.current.tag -# -# code = finputs.__code__ -# if finputs.__code__.co_argcount == 0: -# arg_names = ["i%d" % i for i in range(ndim)] -# else: -# arg_names = code.co_varnames[:code.co_argcount] -# -# if len(out_dims) != len(arg_names): -# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) -# -# out_var = [_IterVar((0, extent), arg_name, 0) -# for arg_name, extent in zip(arg_names, out_dims)] -# if isinstance(raxis, _schedule.IterVar): -# raxis = [raxis] -# if raxis is None: -# raxis = [] -# tensor_regions = finputs(*[v.var for v in out_var]) -# -# op = _api_internal._TensorOp(name, -# tag, -# out_var, -# raxis, -# [x.tensor for x in tensor_regions], -# [_get_region(x) for x in tensor_regions], -# intrin) -# # only support single output -# return op.output(0) - - def extern(shape, inputs, fcompute, diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index b20a68ac0800..f32b70eb9a12 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -161,6 +161,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") +@register_node +class TensorComputeOp(Operation): + """Tensor operation.""" + pass + + @register_node class ScanOp(Operation): """Scan operation.""" @@ -174,9 +180,3 @@ def scan_axis(self): class ExternOp(Operation): """Extern operation.""" pass - - -@register_node -class TensorComputeOp(Operation): - """Tensor operation.""" - pass diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 818c487d0f35..04c24f21eb40 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -6,9 +6,25 @@ from . import stmt as _stmt from . import make as _make from . import tensor as _tensor +from . import schedule as _schedule from .build_module import current_build_config from ._ffi.node import NodeBase, register_node + +def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(_api.Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var + else: + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + @register_node class TensorIntrin(NodeBase): """Tensor intrinsic functions for certain computation. @@ -19,7 +35,7 @@ class TensorIntrin(NodeBase): """ def __call__(self, *args, **kwargs): tensors = [x.tensor for x in args] - regions = [_api._get_region(x) for x in args] + regions = [_get_region(x) for x in args] reduce_axis = [] if "reduce_axis" in kwargs: reduce_axis = kwargs["reduce_axis"] diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 9544de33a222..9b1a58abcee4 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -37,12 +37,6 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -// TensorIntrinCall TensorIntrin::operator()(Array inputs) const { -// using HalideIR::Internal::Call; -// LOG(FATAL) << "CallTensorIntrin"; -// CHECK_EQ(tensors.size(), regions.size()); -// } - Tensor TensorNode::make(Array shape, Type dtype, Operation op, diff --git a/src/op/tensor_op.cc b/src/op/tensor_compute_op.cc similarity index 85% rename from src/op/tensor_op.cc rename to src/op/tensor_compute_op.cc index fd9122c8b092..06f9972fd83c 100644 --- a/src/op/tensor_op.cc +++ b/src/op/tensor_compute_op.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors - * \brief Compute Op. 
- * \file compute_op.cc + * \brief Tensor Compute Op. + * \file tensor_compute_op.cc */ #include #include @@ -29,7 +29,7 @@ int TensorComputeOpNode::num_outputs() const { } Array TensorComputeOpNode::root_iter_vars() const { - Array ret = axis; + Array ret = out_axis; for (IterVar iv : tensor_axis) { ret.push_back(iv); } @@ -45,7 +45,7 @@ Type TensorComputeOpNode::output_dtype(size_t i) const { Array TensorComputeOpNode::output_shape(size_t i) const { Array shape; - for (const auto& ivar : this->axis) { + for (const auto& ivar : this->out_axis) { shape.push_back(ivar->dom->extent); } size_t index = this->inputs.size() + i; @@ -58,12 +58,12 @@ Array TensorComputeOpNode::output_shape(size_t i) const { Operation TensorComputeOpNode::make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, TensorIntrinCall intrin_call) { return TensorComputeOpNode::make(name, tag, - axis, + out_axis, tensor_axis, intrin_call->reduce_axis, intrin_call->tensors, @@ -73,7 +73,7 @@ Operation TensorComputeOpNode::make(std::string name, Operation TensorComputeOpNode::make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, Array reduce_axis, Array tensors, @@ -82,7 +82,15 @@ Operation TensorComputeOpNode::make(std::string name, auto n = make_node(); n->name = name; n->tag = tag; + Array axis; + for (auto iv : out_axis) { + axis.push_back(iv); + } + for (auto iv : tensor_axis) { + axis.push_back(iv); + } n->axis = axis; + n->out_axis = out_axis; n->tensor_axis = tensor_axis; n->reduce_axis = reduce_axis; n->inputs = tensors; @@ -148,10 +156,10 @@ void TensorComputeOpNode::GatherBound( const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const { const TensorDom& tdom = tensor_dom.at(self.output(0)); - for (size_t i = 0; i < this->axis.size(); ++i) { - Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); - CHECK(!out_dom_map->count(this->axis[i])); - (*out_dom_map)[this->axis[i]] = r; + for (size_t i = 0; i < this->out_axis.size(); ++i) { + Range r = arith::Union(tdom.data.at(i)).cover_range(this->out_axis[i]->dom); + CHECK(!out_dom_map->count(this->out_axis[i])); + (*out_dom_map)[this->out_axis[i]] = r; } for (size_t i = 0; i < this->reduce_axis.size(); ++i) { CHECK(!out_dom_map->count(this->reduce_axis[i])); @@ -165,7 +173,7 @@ Stmt TensorComputeOpNode::BuildRealize( const Stmt& body) const { CHECK_EQ(stage->op.get(), this); HalideIR::Internal::Region bounds; - for (IterVar iv : this->axis) { + for (IterVar iv : this->out_axis) { bounds.push_back(realize_map.at(iv)); } size_t out_buff_idx = this->intrin->buffers.size(); @@ -178,8 +186,8 @@ Stmt TensorComputeOpNode::BuildRealize( realize = ir::Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute - for (size_t i = 0; i < this->axis.size(); ++i) { - auto it = stage->iter_var_attrs.find(this->axis[i]); + for (size_t i = 0; i < this->out_axis.size(); ++i) { + auto it = stage->iter_var_attrs.find(this->out_axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { @@ -224,7 +232,7 @@ ComputeLoopNest MakeLoopNest( for (IterVar iv : self->reduce_axis) { update_state[iv] = 2; } - for (IterVar iv : self->axis) { + for (IterVar iv : self->out_axis) { update_state[iv] = 1; } // find which iter var is related to reduction and which is related to axis. 
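The next hunk folds MakeTensorComputeOp into TensorComputeOpNode::BuildProvide. For orientation, a minimal sketch (Python, `import tvm` assumed, placeholder packed-function names) of the statement triple a reducing intrinsic returns from its intrin_func; BuildProvide stitches these into the loop nest as the plain body, reduce_init, and reduce_update respectively:

    def intrin_func(ins, outs):
        x, y = ins
        z = outs[0]
        # complete computation, used when the reduction is not split
        body = tvm.call_packed("gemm", x.access_ptr("r"), y.access_ptr("r"),
                               z.access_ptr("w"))
        # reduce_init: reset the accumulation buffer
        reset = tvm.call_packed("fill_zero", z.access_ptr("w"))
        # reduce_update: accumulate one slice of the reduction
        update = tvm.call_packed("gemm_add", x.access_ptr("r"), y.access_ptr("r"),
                                 z.access_ptr("rw"))
        return body, reset, update

When reduce_init is left undefined, BuildProvide instead uses the plain body to reset on the first iteration via TransformUpdate.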
@@ -264,21 +272,23 @@ ComputeLoopNest MakeLoopNest( } -Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +Stmt TensorComputeOpNode::BuildProvide( + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { + CHECK_EQ(stage->op.operator->(), this); + // Start bind data. Stmt nop = Evaluate::make(0); std::vector input_bind_nest, output_bind_nest; - Array inputs = self->InputTensors(); + Array inputs = this->InputTensors(); // input binding size_t num_inputs = inputs.size(); for (size_t i = 0; i < num_inputs; ++i) { Tensor tensor = inputs[i]; - Region region = self->input_regions[i]; - Buffer buffer = self->intrin->buffers[i]; + Region region = this->input_regions[i]; + Buffer buffer = this->intrin->buffers[i]; Array bind_spec{buffer, tensor}; Array tuple; @@ -292,13 +302,13 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, } // output binding - for (int i = 0; i < self->num_outputs(); ++i) { + for (int i = 0; i < this->num_outputs(); ++i) { Tensor tensor = stage->op.output(i); - Buffer buffer = self->intrin->buffers[num_inputs + i]; + Buffer buffer = this->intrin->buffers[num_inputs + i]; Array bind_spec{buffer, tensor}; Array tuple; - for (const IterVar ivar : self->axis) { + for (const IterVar ivar : this->out_axis) { tuple.push_back(ivar->var); tuple.push_back(1); } @@ -317,16 +327,16 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, ir::ArgBinder binder(&vmap); size_t tloc = stage->leaf_iter_vars.size(); - ComputeLoopNest n = MakeLoopNest(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop); - if (self->reduce_axis.size() == 0) { + if (this->reduce_axis.size() == 0) { std::vector > nest( n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(op::MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); - CHECK(self->intrin->body.defined()) - << "Normal store op for intrin " << self << " is not defined"; - Stmt body = MergeNest(output_bind_nest, self->intrin->body); + CHECK(this->intrin->body.defined()) + << "Normal store op for intrin " << this << " is not defined"; + Stmt body = MergeNest(output_bind_nest, this->intrin->body); body = MergeNest(input_bind_nest, body); body = ir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); @@ -335,26 +345,26 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, return ret; } else { // Need to split reduction - CHECK(self->intrin->reduce_update.defined()) + CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; // Need init and update steps - CHECK_NE(self->reduce_axis.size(), 0U); + CHECK_NE(this->reduce_axis.size(), 0U); std::vector > common( n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); std::vector > update_nest( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - if (self->intrin->reduce_init.defined()) { + if (this->intrin->reduce_init.defined()) { // init nest std::vector > init_nest( n.init_nest.begin(), n.init_nest.begin() + tloc + 1); init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - Stmt init = MergeNest(output_bind_nest, self->intrin->reduce_init); + Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); init = op::Substitute(init, n.init_vmap); init = MergeNest(init_nest, init); // The 
update - Stmt update = MergeNest(output_bind_nest, self->intrin->reduce_update); + Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); update = MergeNest(input_bind_nest, update); update = ir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); @@ -363,11 +373,11 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, return MergeNest(common, Block::make(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(self->intrin->body.defined()) + CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; Stmt update = TransformUpdate(stage, dom_map, n, - self->intrin->body, - self->intrin->reduce_update); + this->intrin->body, + this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = ir::Substitute(update, vmap); @@ -379,14 +389,4 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, } } - -Stmt TensorComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { - CHECK_EQ(stage->op.operator->(), this); - Stmt ret = MakeTensorComputeOp(this, stage, dom_map, debug_keep_trivial_loop); - return ret; -} - } // namespace tvm diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 02704fb56124..20c2228b780d 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -135,30 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } - -// for tensor_op -Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, - const Array& tensor_array, - const std::string& scope) { - size_t tensor_size = tensor_array.size(); - sch->InvalidateCache(); - Tensor tensor = tensor_array[0]; - Stage orig_stage = sch[tensor->op]; - const TensorComputeOpNode* tensor_op = orig_stage->op.as(); - std::unordered_set red_axis; - for (IterVar iv : tensor_op->reduce_axis) { +template +void PrepareAxisMapping(Stage orig_stage, + OpType* op, + std::unordered_set* p_red_axis, + Array* p_new_axis, + std::unordered_map* p_dom_map, + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, + std::vector* p_predicates) { + auto& red_axis = *p_red_axis; + auto& new_axis = *p_new_axis; + auto& dom_map = *p_dom_map; + auto& vsub = *p_vsub; + auto& vsub2newvar = *p_vsub2newvar; + auto& predicates = *p_predicates; + + for (IterVar iv : op->reduce_axis) { red_axis.insert(iv); } - std::unordered_map dom_map; - Array new_axis; - - for (IterVar iv : tensor_op->root_iter_vars()) { + for (IterVar iv : op->axis) { dom_map[iv] = iv->dom; } schedule::PassDownDomain(orig_stage, &dom_map, true); - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; { // The source->cache std::unordered_map value_map; @@ -171,7 +170,7 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { - value_map[iv] = dom->extent; + value_map[iv] = dom->min; } else { value_map[iv] = iv->var; vsub2newvar[iv->var.get()] = new_iv->var; @@ -179,75 +178,32 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, } // skip reduction iteration. 
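// Reduction itervars are not relaid out by the cache stage, so they are
// exempt from the bound-check predicates generated below.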
std::unordered_set skip_bound_check; - for (IterVar iv : tensor_op->reduce_axis) { + for (IterVar iv : op->reduce_axis) { skip_bound_check.insert(iv); } schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); predicates = schedule::MakeBoundCheck( orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis - for (IterVar iv : tensor_op->axis) { - vsub[iv->var.get()] = value_map.at(iv); - } - } - - Array new_regions; - for (Region old_region : tensor_op->input_regions) { - Region region; - for (Range r : old_region) { - Expr min = VarReplacer(vsub2newvar).Mutate(r->min); - Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); - region.push_back(Range::make_by_min_extent(min, extent)); - } - new_regions.push_back(region); - } - - Operation cache_op = TensorComputeOpNode::make( - tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->tensor_axis, - tensor_op->inputs, new_regions, tensor_op->intrin); - - // axis will be used in generating compute op - Array compute_axis = tensor_op->axis; - for (IterVar iv : tensor_op->tensor_axis) { - // new tensor axis with kDataPar IterVar type - IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); - compute_axis.push_back(aiv); - } - - // The reader args - Array args; - { - // cache->compute - std::unordered_map value_map; - for (IterVar iv : compute_axis) { - value_map[iv] = iv->var; - } - schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); - for (IterVar iv : orig_stage->leaf_iter_vars) { - if (red_axis.count(iv)) continue; - args.push_back(value_map.at(iv)); - } - // tensorized region axis - for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { - IterVar iv = compute_axis[tensor_op->axis.size() + i]; - args.push_back(value_map.at(iv)); + for (IterVar iv : op->axis) { + if (value_map.count(iv)) { + vsub[iv->var.get()] = value_map.at(iv); + } // to handle tensor axis } } +} +Array ReplaceOriginalOp(Schedule sch, + Stage orig_stage, + const std::string& scope, + Operation cache_op, + Operation orig_new_op, + size_t tensor_size) { Array cache_tensor_list; - Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_tensor_list.push_back(cache_tensor); - cache_expr_list.push_back(cache_tensor(args)); } - - Operation orig_new_op = ComputeOpNode::make( - tensor_op->name, tensor_op->tag, {}, - compute_axis, cache_expr_list); - - // The replace of the dataflow std::unordered_map vmap; std::unordered_map rvmap; @@ -283,61 +239,24 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, // Cache write and relayout the data according to loop pattern Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, - const std::string& scope) { + const Array& tensor_array, + const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); Tensor tensor = tensor_array[0]; Stage orig_stage = sch[tensor->op]; - if (!strcmp(orig_stage->op->type_key(), "TensorComputeOp")) { - return CacheWriteWithReLayoutForTensorComputeOp(sch, tensor_array, scope); - } const ComputeOpNode* compute = orig_stage->op.as(); + std::unordered_set red_axis; - for (IterVar iv : compute->reduce_axis) { - red_axis.insert(iv); - } - std::unordered_map dom_map; Array new_axis; + std::unordered_map dom_map; - for (IterVar iv : compute->axis) { - dom_map[iv] = iv->dom; - } - schedule::PassDownDomain(orig_stage, &dom_map, true); std::unordered_map vsub; std::unordered_map vsub2newvar; std::vector predicates; - { - // 
The source->cache - std::unordered_map value_map; - for (IterVar iv : orig_stage->leaf_iter_vars) { - if (red_axis.count(iv)) continue; - CHECK_EQ(iv->iter_type, kDataPar) - << "Can only relayout with in data parallel dimensions"; - Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make( - dom, iv->var.copy_with_suffix(".c"), iv->iter_type); - new_axis.push_back(new_iv); - if (is_one(dom->min)) { - value_map[iv] = dom->min; - } else { - value_map[iv] = iv->var; - vsub2newvar[iv->var.get()] = new_iv->var; - } - } - // skip reduction iteration. - std::unordered_set skip_bound_check; - for (IterVar iv : compute->reduce_axis) { - skip_bound_check.insert(iv); - } - schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); - predicates = schedule::MakeBoundCheck( - orig_stage, dom_map, value_map, true, skip_bound_check); - // The root axis - for (IterVar iv : compute->axis) { - vsub[iv->var.get()] = value_map.at(iv); - } - } + + PrepareAxisMapping(orig_stage, compute, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); Expr body; Array body_list; @@ -383,46 +302,98 @@ Array CacheWriteWithReLayout(Schedule sch, Operation cache_op = ComputeOpNode::make( compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); - Array cache_tensor_list; + Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); - cache_tensor_list.push_back(cache_tensor); cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = ComputeOpNode::make( compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); - // The replace of the dataflow - std::unordered_map vmap; - std::unordered_map rvmap; - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); - for (size_t i = 0; i < tensor_size; i++) { - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); +} + + +// for tensor compute op +Array CacheWriteWithReLayoutTensor(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const TensorComputeOpNode* tensor_op = orig_stage->op.as(); + CHECK_EQ(tensor_op->num_outputs(), 1) + << "cache write only support single output tensor_compute_op"; + + std::unordered_set red_axis; + Array new_axis; + std::unordered_map dom_map; + + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + + PrepareAxisMapping(orig_stage, tensor_op, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + + + Array new_regions; + for (Region old_region : tensor_op->input_regions) { + Region region; + for (Range r : old_region) { + Expr min = VarReplacer(vsub2newvar).Mutate(r->min); + Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + region.push_back(Range::make_by_min_extent(min, extent)); + } + new_regions.push_back(region); } - ReplaceDataFlow(sch->stages, &vmap, &rvmap); - // mutate orig stage - orig_stage->op = orig_new_op; - orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); - orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; - orig_stage->relations = Array(); - // create schedule for new cached stage. 
- ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage.node_); - sch->stage_map.Set(cache_op, cache_stage); - // Update group - cache_stage->group = orig_stage->group; - if (cache_stage->group.defined()) { - ++cache_stage->group->num_child_stages; + + Operation cache_op = TensorComputeOpNode::make( + tensor_op->name + "." + scope, tensor_op->tag, new_axis, + tensor_op->tensor_axis, tensor_op->reduce_axis, + tensor_op->inputs, new_regions, tensor_op->intrin); + + // axis will be used in generating compute op + Array compute_axis = tensor_op->out_axis; + for (IterVar iv : tensor_op->tensor_axis) { + // new tensor axis with kDataPar IterVar type + IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + compute_axis.push_back(aiv); } - return cache_tensor_list; + + // The reader args + Array args; + { + // cache->compute + std::unordered_map value_map; + for (IterVar iv : compute_axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } + // tensorized region axis + for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { + IterVar iv = compute_axis[tensor_op->out_axis.size() + i]; + args.push_back(value_map.at(iv)); + } + } + + Array cache_expr_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_expr_list.push_back(cache_tensor(args)); + } + Operation orig_new_op = ComputeOpNode::make( + tensor_op->name, tensor_op->tag, {}, + compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); } @@ -441,23 +412,26 @@ Array Schedule::cache_write(const Array& tensor_array, CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } - return CacheWriteWithReLayout(*this, tensor_array, scope); } + Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { + // support original compute and tensor compute both (*this)->InvalidateCache(); - Stage orig_stage = operator[](tensor->op); - // const ComputeOpNode* compute = tensor->op.as(); - // CHECK(compute) - // << "cache write only take ComputeOp as writers"; - // CHECK_EQ(compute->num_outputs(), 1) - // << "cache write only support single output ComputeOp"; - - return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + const char* type_key = tensor->op->type_key(); + if (!strcmp(type_key, "ComputeOp")) { + return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + } else if (!strcmp(type_key, "TensorComputeOp")) { + return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0]; + } else { + LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers"; + return Tensor(); + } } + void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 8e6f4090d403..ef1babbec72d 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -276,6 +276,135 @@ def test_schedule_bound_condition(): stmt = tvm.ir_pass.Simplify(stmt) assert 
(isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) + +def intrin_gemv(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body = tvm.call_packed( + "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + reset = tvm.call_packed( + "fill_zero", zz_ptr, n) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, reset, update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + + +def test_schedule_tensor_compute1(): + # basic: split, reorder, tile + M, N, L = 2048, 1024, 512 + factor, rfactor = 16, 16 + A = tvm.placeholder((N/factor, L/rfactor, factor, rfactor), name='A') + B = tvm.placeholder((M, L/rfactor, rfactor), name='B') + k = tvm.reduce_axis((0, L/rfactor), name='k') + + gemv = intrin_gemv(factor, rfactor) + C = tvm.compute((N, M/factor, factor), + lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k), + name='C') + + s = tvm.create_schedule(C.op) + ai, aj, ax = s[C].op.axis + aio, aii = s[C].split(ai, 16) + s[C].reorder(aio, aj, aii) + aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4) + + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def intrin_vadd(n, cache_read=False, cache_write=False): + scope_ubuf = 'local' + dtype = 'float32' + x = tvm.placeholder((n,), dtype=dtype, name='vx') + y = tvm.placeholder((n,), dtype=dtype, name='vy') + z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z') + s = tvm.create_schedule(z.op) + + def create_buffer(t): + return tvm.decl_buffer(t.shape, t.dtype, + name='W'+t.name, + scope=scope_ubuf, + offset_factor=16) + + binds = {} + if cache_read: + binds[x] = create_buffer(x) + binds[y] = create_buffer(y) + if cache_write: + binds[z] = create_buffer(z) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) + return ib.get() + + with tvm.build_config(offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds) + + +def test_schedule_tensor_compute2(): + # cache_read, cache_write + M = 1024 + factor = 16 + dtype = 'float32' + scope_ubuf = 'local' + + A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M/factor, factor), name="B", dtype=dtype) + + vadd = intrin_vadd(factor, True, True) + C = tvm.compute((M/factor, factor), + lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C') + + s = tvm.create_schedule(C.op) + AL = s.cache_read(A, scope_ubuf, C) + BL = s.cache_read(B, scope_ubuf, C) + CL = s.cache_write(C, scope_ubuf) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def test_schedule_tensor_compute3(): + M = 1024 + factor = 16 + dtype = 'float32' + A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M/factor, factor), name="B", 
dtype=dtype) + Bi = tvm.compute((M/factor, factor), lambda i, j: B[i, j] + 5, name="Bi") + + vadd = intrin_vadd(factor) + C = tvm.compute((M/factor, factor), + lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C') + s = tvm.create_schedule(C.op) + s[Bi].compute_at(s[C], C.op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + if __name__ == "__main__": test_schedule_middle_cache() test_inline_multi_reduce() @@ -294,3 +423,6 @@ def test_schedule_bound_condition(): test_schedule2() test_schedule_cache() test_schedule_bound_condition() + test_schedule_tensor_compute1() + test_schedule_tensor_compute2() + test_schedule_tensor_compute3() From 81ffb68ad48449738e7dc4b89d6ac7a5cd400a72 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 23 Sep 2018 21:25:24 -0700 Subject: [PATCH 5/7] [TensorOp] Fix. --- tests/python/unittest/test_lang_tensor.py | 14 ++++++------ .../unittest/test_schedule_schedule_ops.py | 22 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index aeccf9cddab6..2f49b084b875 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -104,9 +104,9 @@ def intrin_func(ins, outs): vadd = intrin_vadd(factor) - A = tvm.placeholder((m/factor, factor), name="A", dtype=dtype) - B = tvm.placeholder((m/factor, factor), name="B", dtype=dtype) - C = tvm.compute((m/factor, factor), + A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype) + C = tvm.compute((m//factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) s = tvm.create_schedule(C.op) @@ -146,10 +146,10 @@ def intrin_func(ins, outs): vgemm = intrin_gemm(factor1, factor2, factor) - A = tvm.placeholder((M/factor1, L/factor, factor1, factor), name="A", dtype=dtype) - B = tvm.placeholder((N/factor2, L/factor, factor2, factor), name="B", dtype=dtype) - k = tvm.reduce_axis((0, L/factor), name='k') - C = tvm.compute((M/factor1, N/factor2, factor1, factor2), + A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype) + B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype) + k = tvm.reduce_axis((0, L//factor), name='k') + C = tvm.compute((M//factor1, N//factor2, factor1, factor2), lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) s = tvm.create_schedule(C.op) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index ef1babbec72d..72426cfa41a7 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -311,12 +311,12 @@ def test_schedule_tensor_compute1(): # basic: split, reorder, tile M, N, L = 2048, 1024, 512 factor, rfactor = 16, 16 - A = tvm.placeholder((N/factor, L/rfactor, factor, rfactor), name='A') - B = tvm.placeholder((M, L/rfactor, rfactor), name='B') - k = tvm.reduce_axis((0, L/rfactor), name='k') + A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A') + B = tvm.placeholder((M, L//rfactor, rfactor), name='B') + k = tvm.reduce_axis((0, L//rfactor), name='k') gemv = intrin_gemv(factor, rfactor) - C = tvm.compute((N, M/factor, factor), + C = tvm.compute((N, M//factor, factor), lambda i, j: gemv(A[i, k, 0:factor, 
0:factor], B[j, k, 0:rfactor], reduce_axis=k), name='C') @@ -369,11 +369,11 @@ def test_schedule_tensor_compute2(): dtype = 'float32' scope_ubuf = 'local' - A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) - B = tvm.placeholder((M/factor, factor), name="B", dtype=dtype) + A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype) vadd = intrin_vadd(factor, True, True) - C = tvm.compute((M/factor, factor), + C = tvm.compute((M//factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C') s = tvm.create_schedule(C.op) @@ -390,12 +390,12 @@ def test_schedule_tensor_compute3(): M = 1024 factor = 16 dtype = 'float32' - A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) - B = tvm.placeholder((M/factor, factor), name="B", dtype=dtype) - Bi = tvm.compute((M/factor, factor), lambda i, j: B[i, j] + 5, name="Bi") + A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype) + Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi") vadd = intrin_vadd(factor) - C = tvm.compute((M/factor, factor), + C = tvm.compute((M//factor, factor), lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C') s = tvm.create_schedule(C.op) s[Bi].compute_at(s[C], C.op.axis[0]) From 2cd02e0b864c05a77bb39844881a600b25058af9 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 3 Oct 2018 16:21:19 -0700 Subject: [PATCH 6/7] [TensorOp] Remove 'out_axis', 'tensor_axis' fields. --- include/tvm/operation.h | 21 ++-- python/tvm/api.py | 10 +- src/api/api_lang.cc | 5 +- src/op/tensor_compute_op.cc | 107 ++++++++---------- src/schedule/schedule_dataflow_rewrite.cc | 20 ++-- tests/python/unittest/test_lang_tensor.py | 2 + .../unittest/test_schedule_schedule_ops.py | 1 + 7 files changed, 81 insertions(+), 85 deletions(-) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 48591d316422..013aef3e016e 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -186,12 +186,14 @@ class TensorComputeOpNode : public OperationNode { public: Array axis; - Array out_axis; + // Array out_axis; - Array tensor_axis; + // Array tensor_axis; Array reduce_axis; + int sch_ndim; + Array inputs; Array input_regions; @@ -231,23 +233,18 @@ class TensorComputeOpNode : public OperationNode { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("axis", &axis); - v->Visit("out_axis", &out_axis); - v->Visit("tensor_axis", &tensor_axis); v->Visit("reduce_axis", &reduce_axis); + v->Visit("sch_ndim", &sch_ndim); v->Visit("inputs", &inputs); + v->Visit("input_regions", &input_regions); + v->Visit("intrin", &intrin); } static Operation make(std::string name, std::string tag, - Array out_axis, - Array tensor_axis, - TensorIntrinCall intrin_call); - - static Operation make(std::string name, - std::string tag, - Array out_axis, - Array tensor_axis, + Array axis, Array reduce_axis, + int sch_ndim, Array tensors, Array regions, TensorIntrin intrin); diff --git a/python/tvm/api.py b/python/tvm/api.py index cc1792fa511d..f1a96c14d61b 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -262,15 +262,17 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): body = fcompute(*[v.var for v in dim_var]) if isinstance(body, _tensor.TensorIntrinCall): - tensor_var = [] for i, s in enumerate(shape[out_ndim:]): var_name = "ax" + str(i) - tensor_var.append(_IterVar((0, s), var_name, 4)) + dim_var.append(_IterVar((0, s), var_name, 4)) 
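+ # the loop above appends the trailing output dimensions as itervars of
+ # type 4 (opaque to the scheduler): they are computed inside the
+ # intrinsic, so only the first out_ndim axes remain schedulable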
op_node = _api_internal._TensorComputeOp(name, tag, dim_var, - tensor_var, - body) + body.reduce_axis, + out_ndim, + body.tensors, + body.regions, + body.intrin) else: if not isinstance(body, (list, tuple)): body = [body] diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 67377e468a4f..75365da5bf50 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -292,7 +292,10 @@ TVM_REGISTER_API("_TensorComputeOp") args[1], args[2], args[3], - args[4]); + args[4], + args[5], + args[6], + args[7]); }); TVM_REGISTER_API("_ExternOp") diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 06f9972fd83c..345f3c97cbc4 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -29,10 +29,7 @@ int TensorComputeOpNode::num_outputs() const { } Array TensorComputeOpNode::root_iter_vars() const { - Array ret = out_axis; - for (IterVar iv : tensor_axis) { - ret.push_back(iv); - } + Array ret = axis; for (IterVar iv : reduce_axis) { ret.push_back(iv); } @@ -45,57 +42,45 @@ Type TensorComputeOpNode::output_dtype(size_t i) const { Array TensorComputeOpNode::output_shape(size_t i) const { Array shape; - for (const auto& ivar : this->out_axis) { + for (const auto& ivar : this->axis) { shape.push_back(ivar->dom->extent); } - size_t index = this->inputs.size() + i; - for (const auto& dim : this->intrin->buffers[index]->shape) { - shape.push_back(dim); - } return shape; } -Operation TensorComputeOpNode::make(std::string name, - std::string tag, - Array out_axis, - Array tensor_axis, - TensorIntrinCall intrin_call) { - return TensorComputeOpNode::make(name, - tag, - out_axis, - tensor_axis, - intrin_call->reduce_axis, - intrin_call->tensors, - intrin_call->regions, - intrin_call->intrin); -} +// Operation TensorComputeOpNode::make(std::string name, +// std::string tag, +// Array out_axis, +// Array tensor_axis, +// TensorIntrinCall intrin_call) { +// return TensorComputeOpNode::make(name, +// tag, +// out_axis, +// tensor_axis, +// intrin_call->reduce_axis, +// intrin_call->tensors, +// intrin_call->regions, +// intrin_call->intrin); +// } Operation TensorComputeOpNode::make(std::string name, std::string tag, - Array out_axis, - Array tensor_axis, + Array axis, Array reduce_axis, + int sch_ndim, Array tensors, Array regions, TensorIntrin intrin) { auto n = make_node(); - n->name = name; - n->tag = tag; - Array axis; - for (auto iv : out_axis) { - axis.push_back(iv); - } - for (auto iv : tensor_axis) { - axis.push_back(iv); - } - n->axis = axis; - n->out_axis = out_axis; - n->tensor_axis = tensor_axis; - n->reduce_axis = reduce_axis; - n->inputs = tensors; - n->input_regions = regions; - n->intrin = intrin; + n->name = std::move(name); + n->tag = std::move(tag); + n->axis = std::move(axis); + n->reduce_axis = std::move(reduce_axis); + n->sch_ndim = sch_ndim; + n->inputs = std::move(tensors); + n->input_regions = std::move(regions); + n->intrin = std::move(intrin); return Operation(n); } @@ -156,11 +141,12 @@ void TensorComputeOpNode::GatherBound( const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const { const TensorDom& tdom = tensor_dom.at(self.output(0)); - for (size_t i = 0; i < this->out_axis.size(); ++i) { - Range r = arith::Union(tdom.data.at(i)).cover_range(this->out_axis[i]->dom); - CHECK(!out_dom_map->count(this->out_axis[i])); - (*out_dom_map)[this->out_axis[i]] = r; + for (size_t i = 0; i < this->axis.size(); ++i) { + Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); + CHECK(!out_dom_map->count(this->axis[i])); 
+ (*out_dom_map)[this->axis[i]] = r; } + // should I add dom of tensor_vars for (size_t i = 0; i < this->reduce_axis.size(); ++i) { CHECK(!out_dom_map->count(this->reduce_axis[i])); (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom; @@ -173,21 +159,17 @@ Stmt TensorComputeOpNode::BuildRealize( const Stmt& body) const { CHECK_EQ(stage->op.get(), this); HalideIR::Internal::Region bounds; - for (IterVar iv : this->out_axis) { + for (IterVar iv : this->axis) { bounds.push_back(realize_map.at(iv)); } - size_t out_buff_idx = this->intrin->buffers.size(); - for (const Expr extent : this->intrin->buffers[out_buff_idx - 1]->shape) { - bounds.push_back(Range(0, extent)); - } Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i-1); realize = ir::Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute - for (size_t i = 0; i < this->out_axis.size(); ++i) { - auto it = stage->iter_var_attrs.find(this->out_axis[i]); + for (int i = 0; i < sch_ndim; ++i) { + auto it = stage->iter_var_attrs.find(this->axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { @@ -232,8 +214,8 @@ ComputeLoopNest MakeLoopNest( for (IterVar iv : self->reduce_axis) { update_state[iv] = 2; } - for (IterVar iv : self->out_axis) { - update_state[iv] = 1; + for (int i = 0; i < self->sch_ndim; ++i) { + update_state[self->axis[i]] = 1; } // find which iter var is related to reduction and which is related to axis. schedule::PassDownBitMaskOr(stage, &update_state); @@ -308,13 +290,16 @@ Stmt TensorComputeOpNode::BuildProvide( Array bind_spec{buffer, tensor}; Array tuple; - for (const IterVar ivar : this->out_axis) { - tuple.push_back(ivar->var); - tuple.push_back(1); - } - for (const Expr extent : buffer->shape) { - tuple.push_back(0); - tuple.push_back(extent); + for (size_t i = 0; i < this->axis.size(); ++i) { + auto ivar = this->axis[i]; + if (i < static_cast(this->sch_ndim)) { + tuple.push_back(ivar->var); + tuple.push_back(1); + } else { + Range dom = ivar->dom; + tuple.push_back(dom->min); + tuple.push_back(dom->extent); + } } output_bind_nest.emplace_back(AttrStmt::make( diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 20c2228b780d..1ebc8d1f3b7e 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -340,6 +340,12 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + for (int i = tensor_op->sch_ndim; i < static_cast(tensor_op->axis.size()); ++i) { + IterVar iv = tensor_op->axis[i]; + IterVar new_iv = IterVarNode::make( + iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + new_axis.push_back(new_iv); + } Array new_regions; for (Region old_region : tensor_op->input_regions) { Region region; @@ -353,15 +359,15 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Operation cache_op = TensorComputeOpNode::make( tensor_op->name + "." 
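// PassDownBitMaskOr propagates these root flags through the stage's
// split/fuse relations, so each leaf itervar is classified as
// reduction-related, axis-related, or both.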
+ scope, tensor_op->tag, new_axis, - tensor_op->tensor_axis, tensor_op->reduce_axis, + tensor_op->reduce_axis, tensor_op->sch_ndim, tensor_op->inputs, new_regions, tensor_op->intrin); // axis will be used in generating compute op - Array compute_axis = tensor_op->out_axis; - for (IterVar iv : tensor_op->tensor_axis) { - // new tensor axis with kDataPar IterVar type + Array compute_axis = tensor_op->axis; + for (size_t i = tensor_op->sch_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = tensor_op->axis[i]; IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); - compute_axis.push_back(aiv); + compute_axis.Set(i, aiv); } // The reader args @@ -378,8 +384,8 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, args.push_back(value_map.at(iv)); } // tensorized region axis - for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { - IterVar iv = compute_axis[tensor_op->out_axis.size() + i]; + for (size_t i = tensor_op->sch_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = compute_axis[i]; args.push_back(value_map.at(iv)); } } diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 2f49b084b875..6d1515f1219f 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -111,6 +111,7 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) + print(stmt) assert isinstance(stmt.body.body, tvm.stmt.Evaluate) def test_tensor_compute2(): @@ -154,6 +155,7 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) + print(stmt) assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate) assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 72426cfa41a7..9bd188ece68d 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -387,6 +387,7 @@ def test_schedule_tensor_compute2(): def test_schedule_tensor_compute3(): + # compute_at M = 1024 factor = 16 dtype = 'float32' From 714e6f1d9223a8a454736672344b04b4c84ba4b3 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 4 Oct 2018 12:10:41 -0700 Subject: [PATCH 7/7] [TensorOp] Update doc. --- include/tvm/operation.h | 79 +++++++++---------- include/tvm/tensor_intrin.h | 12 ++- python/tvm/api.py | 4 +- src/op/tensor_compute_op.cc | 32 ++------ src/schedule/schedule_dataflow_rewrite.cc | 10 +-- tests/python/unittest/test_lang_tensor.py | 2 - .../unittest/test_schedule_schedule_ops.py | 3 - 7 files changed, 62 insertions(+), 80 deletions(-) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 013aef3e016e..1a1d28ab71bb 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -182,27 +182,19 @@ class PlaceholderOpNode : public OperationNode { TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); }; -class TensorComputeOpNode : public OperationNode { +/*! + * \brief A Compute op that compute a tensor on certain domain. + */ +class TVM_DLL ComputeOpNode : public OperationNode { public: + /*! \brief IterVar on each axis */ Array axis; - - // Array out_axis; - - // Array tensor_axis; - + /*! \brief IterVar on each reduction axis, if the body is a Reduce */ Array reduce_axis; - - int sch_ndim; - - Array inputs; - - Array input_regions; - - TensorIntrin intrin; - + /*! 
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 013aef3e016e..1a1d28ab71bb 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -182,27 +182,19 @@ class PlaceholderOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
 };

-class TensorComputeOpNode : public OperationNode {
+/*!
+ * \brief A Compute op that computes a tensor on a certain domain.
+ */
+class TVM_DLL ComputeOpNode : public OperationNode {
  public:
+  /*! \brief IterVar on each axis */
   Array<IterVar> axis;
-
-  // Array<IterVar> out_axis;
-
-  // Array<IterVar> tensor_axis;
-
+  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
   Array<IterVar> reduce_axis;
-
-  int sch_ndim;
-
-  Array<Tensor> inputs;
-
-  Array<Region> input_regions;
-
-  TensorIntrin intrin;
-
+  /*! \brief the compute expression */
+  Array<Expr> body;
   /*! \brief constructor */
-  TensorComputeOpNode() {}
-
+  ComputeOpNode() {}
   // override functions
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
@@ -232,40 +224,40 @@
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
     v->Visit("axis", &axis);
     v->Visit("reduce_axis", &reduce_axis);
-    v->Visit("sch_ndim", &sch_ndim);
-    v->Visit("inputs", &inputs);
-    v->Visit("input_regions", &input_regions);
-    v->Visit("intrin", &intrin);
+    v->Visit("body", &body);
   }
-
   static Operation make(std::string name,
                         std::string tag,
+                        Map<std::string, NodeRef> attrs,
                         Array<IterVar> axis,
-                        Array<IterVar> reduce_axis,
-                        int sch_ndim,
-                        Array<Tensor> tensors,
-                        Array<Region> regions,
-                        TensorIntrin intrin);
+                        Array<Expr> body);

-  static constexpr const char* _type_key = "TensorComputeOp";
-  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
+  static constexpr const char* _type_key = "ComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
 };

 /*!
- * \brief A Compute op that computes a tensor on a certain domain.
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
  */
-class TVM_DLL ComputeOpNode : public OperationNode {
+class TensorComputeOpNode : public OperationNode {
  public:
   /*! \brief IterVar on each axis */
   Array<IterVar> axis;
-  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
+  /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
   Array<IterVar> reduce_axis;
-  /*! \brief the compute expression */
-  Array<Expr> body;
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
   /*! \brief constructor */
-  ComputeOpNode() {}
+  TensorComputeOpNode() {}
   // override functions
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
@@ -295,19 +287,24 @@
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
     v->Visit("axis", &axis);
     v->Visit("reduce_axis", &reduce_axis);
-    v->Visit("body", &body);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
   }
   static Operation make(std::string name,
                         std::string tag,
-                        Map<std::string, NodeRef> attrs,
                         Array<IterVar> axis,
-                        Array<Expr> body);
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions);

-  static constexpr const char* _type_key = "ComputeOp";
-  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
 };

 /*!
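The key new field is schedulable_ndim: the first schedulable_ndim entries of axis are ordinary loops that split/reorder/tile may manipulate, while the remaining entries belong to the intrinsic and must not be touched. Continuing the vadd sketch above (assuming, as usual, that fields registered in VisitAttrs are readable from Python):

    print(type(C.op))             # tvm.tensor.TensorComputeOp
    print(C.op.schedulable_ndim)  # 1: only the outer axis i is schedulable
    print(len(C.op.axis))         # 2: i plus the tensorized inner axis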
diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h
index f70735b02264..fbee4bccc0bf 100644
--- a/include/tvm/tensor_intrin.h
+++ b/include/tvm/tensor_intrin.h
@@ -91,10 +91,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const {
 }


-// Internal node container of tensor intrinsics.
+// Internal node container of a tensor intrinsic call.
 class TensorIntrinCallNode;

-/*! \brief Tensor intrinsic node. */
+/*! \brief Tensor intrinsic call node. */
 class TensorIntrinCall : public NodeRef {
  public:
   TensorIntrinCall() {}
@@ -111,9 +111,16 @@ class TensorIntrinCall : public NodeRef {

 class TensorIntrinCallNode : public Node {
  public:
+  /*! \brief the tensor intrinsic */
   TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
   Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
   Array<Region> regions;
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   *  intrin will use the reduce axis
+   */
   Array<IterVar> reduce_axis;

   void VisitAttrs(AttrVisitor* v) final {
@@ -122,7 +129,6 @@ class TensorIntrinCallNode : public Node {
     v->Visit("regions", &regions);
     v->Visit("reduce_axis", &reduce_axis);
   }
-
   static TensorIntrinCall make(TensorIntrin intrin,
                                Array<Tensor> tensors,
                                Array<Region> regions,
diff --git a/python/tvm/api.py b/python/tvm/api.py
index f1a96c14d61b..793afe52e5fd 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -270,9 +270,9 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
             dim_var,
             body.reduce_axis,
             out_ndim,
+            body.intrin,
             body.tensors,
-            body.regions,
-            body.intrin)
+            body.regions)
     else:
         if not isinstance(body, (list, tuple)):
             body = [body]
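TensorIntrinCallNode is the value produced when a TensorIntrin is applied to tensor slices; the compute() change above simply unpacks its intrin, tensors, regions, and reduce_axis fields in the new argument order. When the intrinsic reduces, the reduction axis is supplied at the call site. A hedged sketch in the style of the gemv test (the gemv intrinsic, the M/N/L/factor shapes, and the reduce_axis keyword of TensorIntrin.__call__ are assumed from this series, not shown here):

    # Assumes `gemv` is a TensorIntrin computing z[i] += x[i, kk] * y[kk]
    # over an inner chunk of length `factor`, declared along the lines of
    # test_tensor_compute2.
    M, N, L, factor = 1024, 512, 64, 16
    A = tvm.placeholder((M, L//factor, factor), name='A')
    B = tvm.placeholder((N, L//factor, factor), name='B')
    k = tvm.reduce_axis((0, L//factor), name='k')
    C = tvm.compute((M, N),
                    lambda i, j: gemv(A[i, k, 0:factor], B[j, k, 0:factor],
                                      reduce_axis=k),
                    name='C')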
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
index 345f3c97cbc4..f9b8188d4685 100644
--- a/src/op/tensor_compute_op.cc
+++ b/src/op/tensor_compute_op.cc
@@ -49,38 +49,23 @@ Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
 }


-// Operation TensorComputeOpNode::make(std::string name,
-//                                     std::string tag,
-//                                     Array<IterVar> out_axis,
-//                                     Array<IterVar> tensor_axis,
-//                                     TensorIntrinCall intrin_call) {
-//   return TensorComputeOpNode::make(name,
-//                                    tag,
-//                                    out_axis,
-//                                    tensor_axis,
-//                                    intrin_call->reduce_axis,
-//                                    intrin_call->tensors,
-//                                    intrin_call->regions,
-//                                    intrin_call->intrin);
-// }
-
 Operation TensorComputeOpNode::make(std::string name,
                                     std::string tag,
                                     Array<IterVar> axis,
                                     Array<IterVar> reduce_axis,
-                                    int sch_ndim,
+                                    int schedulable_ndim,
+                                    TensorIntrin intrin,
                                     Array<Tensor> tensors,
-                                    Array<Region> regions,
-                                    TensorIntrin intrin) {
+                                    Array<Region> regions) {
   auto n = make_node<TensorComputeOpNode>();
   n->name = std::move(name);
   n->tag = std::move(tag);
   n->axis = std::move(axis);
   n->reduce_axis = std::move(reduce_axis);
-  n->sch_ndim = sch_ndim;
+  n->schedulable_ndim = std::move(schedulable_ndim);
+  n->intrin = std::move(intrin);
   n->inputs = std::move(tensors);
   n->input_regions = std::move(regions);
-  n->intrin = std::move(intrin);
   return Operation(n);
 }

@@ -146,7 +131,6 @@ void TensorComputeOpNode::GatherBound(
     CHECK(!out_dom_map->count(this->axis[i]));
     (*out_dom_map)[this->axis[i]] = r;
   }
-  // should I add dom of tensor_vars
   for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
     CHECK(!out_dom_map->count(this->reduce_axis[i]));
     (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
@@ -168,7 +152,7 @@ Stmt TensorComputeOpNode::BuildRealize(
     realize = ir::Realize::make(t->op, t->value_index,
         t->dtype, bounds, const_true(), realize);
     // alignment requirement, only useful for compute
-    for (int i = 0; i < sch_ndim; ++i) {
+    for (int i = 0; i < schedulable_ndim; ++i) {
       auto it = stage->iter_var_attrs.find(this->axis[i]);
       if (it != stage->iter_var_attrs.end()) {
         IterVarAttr attr = (*it).second;
@@ -214,7 +198,7 @@ ComputeLoopNest MakeLoopNest(
   for (IterVar iv : self->reduce_axis) {
     update_state[iv] = 2;
   }
-  for (int i = 0; i < self->sch_ndim; ++i) {
+  for (int i = 0; i < self->schedulable_ndim; ++i) {
     update_state[self->axis[i]] = 1;
   }
   // find which iter var is related to reduction and which is related to axis.
@@ -292,7 +276,7 @@ Stmt TensorComputeOpNode::BuildProvide(
     Array<Expr> tuple;
     for (size_t i = 0; i < this->axis.size(); ++i) {
       auto ivar = this->axis[i];
-      if (i < static_cast<size_t>(this->sch_ndim)) {
+      if (i < static_cast<size_t>(this->schedulable_ndim)) {
         tuple.push_back(ivar->var);
         tuple.push_back(1);
       } else {
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 1ebc8d1f3b7e..ccf7fd617194 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -340,7 +340,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
       &red_axis, &new_axis, &dom_map,
       &vsub, &vsub2newvar, &predicates);

-  for (int i = tensor_op->sch_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
     IterVar iv = tensor_op->axis[i];
     IterVar new_iv = IterVarNode::make(
         iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
@@ -359,12 +359,12 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
   Operation cache_op = TensorComputeOpNode::make(
       tensor_op->name + "." + scope, tensor_op->tag, new_axis,
-      tensor_op->reduce_axis, tensor_op->sch_ndim,
-      tensor_op->inputs, new_regions, tensor_op->intrin);
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions);

   // axis will be used in generating compute op
   Array<IterVar> compute_axis = tensor_op->axis;
-  for (size_t i = tensor_op->sch_ndim; i < tensor_op->axis.size(); ++i) {
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
     IterVar iv = tensor_op->axis[i];
     IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
     compute_axis.Set(i, aiv);
@@ -384,7 +384,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
       args.push_back(value_map.at(iv));
     }
     // tensorized region axis
-    for (size_t i = tensor_op->sch_ndim; i < tensor_op->axis.size(); ++i) {
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
       IterVar iv = compute_axis[i];
       args.push_back(value_map.at(iv));
     }
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 6d1515f1219f..2f49b084b875 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -111,7 +111,6 @@ def intrin_func(ins, outs):

     s = tvm.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    print(stmt)
     assert isinstance(stmt.body.body, tvm.stmt.Evaluate)

 def test_tensor_compute2():
@@ -155,7 +154,6 @@ def intrin_func(ins, outs):

     s = tvm.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    print(stmt)
     assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
     assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
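One consequence of the CacheWriteWithReLayoutTensor change above: cache_write on a tensorized stage re-creates the inner, non-schedulable axes with a ".c" suffix for the cache stage, while the compute stage gets plain kDataPar copies. A hedged sketch, reusing A, B, C and vadd from the earlier sketch ("local" stands in for the target-specific scope the tests use):

    s = tvm.create_schedule(C.op)
    CL = s.cache_write(C, "local")  # CL is itself a TensorComputeOp stage;
                                    # its inner axis is a fresh ".c" copy
    print(tvm.lower(s, [A, B, C], simple_mode=True))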
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 9bd188ece68d..8774514cfa17 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -326,7 +326,6 @@ def test_schedule_tensor_compute1():
     s[C].reorder(aio, aj, aii)
     aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)

-    print(tvm.lower(s, [A, B, C], simple_mode=True))
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -380,7 +379,6 @@ def test_schedule_tensor_compute2():
     AL = s.cache_read(A, scope_ubuf, C)
     BL = s.cache_read(B, scope_ubuf, C)
     CL = s.cache_write(C, scope_ubuf)
-    print(tvm.lower(s, [A, B, C], simple_mode=True))
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -400,7 +398,6 @@ def test_schedule_tensor_compute3():
         lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]),
         name='C')
     s = tvm.create_schedule(C.op)
     s[Bi].compute_at(s[C], C.op.axis[0])
-    print(tvm.lower(s, [A, B, C], simple_mode=True))
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
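Finally, test_schedule_tensor_compute3 above shows the intended interaction with compute_at: a producer may only be anchored at one of the first schedulable_ndim axes of the tensorized consumer. Rebuilt as a standalone sketch (again reusing M, factor, A, B and vadd from the earlier sketch):

    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name='Bi')
    C = tvm.compute((M//factor, factor),
                    lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]),
                    name='C')
    s = tvm.create_schedule(C.op)
    # C.op.axis[0] is the one schedulable axis; the inner intrinsic axis is off-limits.
    s[Bi].compute_at(s[C], C.op.axis[0])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)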