diff --git a/include/tvm/operation.h b/include/tvm/operation.h index f87083303903..48591d316422 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -186,6 +186,8 @@ class TensorComputeOpNode : public OperationNode { public: Array axis; + Array out_axis; + Array tensor_axis; Array reduce_axis; @@ -229,6 +231,7 @@ class TensorComputeOpNode : public OperationNode { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("axis", &axis); + v->Visit("out_axis", &out_axis); v->Visit("tensor_axis", &tensor_axis); v->Visit("reduce_axis", &reduce_axis); v->Visit("inputs", &inputs); @@ -236,13 +239,13 @@ class TensorComputeOpNode : public OperationNode { static Operation make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, TensorIntrinCall intrin_call); static Operation make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, Array reduce_axis, Array tensors, diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 6e2db89496d1..f70735b02264 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -26,14 +26,6 @@ class TensorIntrin : public NodeRef { */ inline const TensorIntrinNode* operator->() const; - // template - // inline Stmt operator()(Args&& ...args) const { - // Array inputs{std::forward(args)...}; - // return operator()(inputs); - // } - - // TVM_DLL TensorIntrinCall operator()(Array inputs) const; - /*! \brief specify container node */ using ContainerType = TensorIntrinNode; }; diff --git a/python/tvm/api.py b/python/tvm/api.py index 8f5068e06cfa..cc1792fa511d 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -3,7 +3,6 @@ from __future__ import absolute_import as _abs from numbers import Integral as _Integral -from collections import namedtuple from ._ffi.base import string_types from ._ffi.node import register_node, NodeBase @@ -244,6 +243,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.current.tag shape = (shape,) if isinstance(shape, _expr.Expr) else shape + # for python3 + shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) ndim = len(shape) code = fcompute.__code__ @@ -254,7 +255,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): arg_names = code.co_varnames[:code.co_argcount] out_ndim = code.co_argcount - # TODO check ndim, arg_names if out_ndim != len(arg_names): raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) @@ -264,8 +264,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): if isinstance(body, _tensor.TensorIntrinCall): tensor_var = [] for i, s in enumerate(shape[out_ndim:]): - name = "ax" + str(i) - tensor_var.append(_IterVar((0, s), name, 4)) + var_name = "ax" + str(i) + tensor_var.append(_IterVar((0, s), var_name, 4)) op_node = _api_internal._TensorComputeOp(name, tag, dim_var, @@ -275,7 +275,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): if not isinstance(body, (list, tuple)): body = [body] body = convert(body) - # print('body: {0}'.format(body)) op_node = _api_internal._ComputeOp( name, tag, attrs, dim_var, body) @@ -353,88 +352,6 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr return res[0] if len(res) == 1 else res -def _get_region(tslice): - region = [] - for idx in tslice.indices: - if isinstance(idx, slice): - assert idx.step is None - region.append(Range(idx.start, idx.stop)) - else: - if 
isinstance(idx, _schedule.IterVar): - begin = idx.var - else: - begin = idx - region.append(_make.range_by_min_extent(begin, 1)) - return region - - -# def tensor_op(out_dims, -# in_dims, # pylint: disable=unused-argument -# finputs, -# intrin, -# raxis=None, -# name='tensor_op', -# tag=""): -# """Construct new tensors with intrinsic. -# -# Parameters -# ---------- -# out_dims: tuple -# The dimensions out of the tensorized region, which can be -# scheduled through `reorder`, `split`. -# -# in_dims: tuple -# The dimensions inside of the tensorized region, which cannot -# be manipulated. -# -# finputs: lambda function of out_dims -> list of TensorSlice -# Specifies involved regions of input tensors. -# -# tensor_intrin : TensorIntrin -# The tensor intrinsic used for computation. -# -# raxis : IterVar -# An iteration variable representing the value. -# -# name: str, optional -# The name hint of the tensor -# -# tag: str, optional -# Additonal tag information about the compute. -# """ -# if _tag.TagScope.current is not None: -# if tag != "": -# raise ValueError("nested tag is not allowed for now") -# tag = _tag.TagScope.current.tag -# -# code = finputs.__code__ -# if finputs.__code__.co_argcount == 0: -# arg_names = ["i%d" % i for i in range(ndim)] -# else: -# arg_names = code.co_varnames[:code.co_argcount] -# -# if len(out_dims) != len(arg_names): -# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims) -# -# out_var = [_IterVar((0, extent), arg_name, 0) -# for arg_name, extent in zip(arg_names, out_dims)] -# if isinstance(raxis, _schedule.IterVar): -# raxis = [raxis] -# if raxis is None: -# raxis = [] -# tensor_regions = finputs(*[v.var for v in out_var]) -# -# op = _api_internal._TensorOp(name, -# tag, -# out_var, -# raxis, -# [x.tensor for x in tensor_regions], -# [_get_region(x) for x in tensor_regions], -# intrin) -# # only support single output -# return op.output(0) - - def extern(shape, inputs, fcompute, diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index b20a68ac0800..f32b70eb9a12 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -161,6 +161,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") +@register_node +class TensorComputeOp(Operation): + """Tensor operation.""" + pass + + @register_node class ScanOp(Operation): """Scan operation.""" @@ -174,9 +180,3 @@ def scan_axis(self): class ExternOp(Operation): """Extern operation.""" pass - - -@register_node -class TensorComputeOp(Operation): - """Tensor operation.""" - pass diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 818c487d0f35..04c24f21eb40 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -6,9 +6,25 @@ from . import stmt as _stmt from . import make as _make from . import tensor as _tensor +from . import schedule as _schedule from .build_module import current_build_config from ._ffi.node import NodeBase, register_node + +def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(_api.Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var + else: + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + @register_node class TensorIntrin(NodeBase): """Tensor intrinsic functions for certain computation. 
@@ -19,7 +35,7 @@ class TensorIntrin(NodeBase): """ def __call__(self, *args, **kwargs): tensors = [x.tensor for x in args] - regions = [_api._get_region(x) for x in args] + regions = [_get_region(x) for x in args] reduce_axis = [] if "reduce_axis" in kwargs: reduce_axis = kwargs["reduce_axis"] diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 9544de33a222..9b1a58abcee4 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -37,12 +37,6 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -// TensorIntrinCall TensorIntrin::operator()(Array inputs) const { -// using HalideIR::Internal::Call; -// LOG(FATAL) << "CallTensorIntrin"; -// CHECK_EQ(tensors.size(), regions.size()); -// } - Tensor TensorNode::make(Array shape, Type dtype, Operation op, diff --git a/src/op/tensor_op.cc b/src/op/tensor_compute_op.cc similarity index 85% rename from src/op/tensor_op.cc rename to src/op/tensor_compute_op.cc index fd9122c8b092..06f9972fd83c 100644 --- a/src/op/tensor_op.cc +++ b/src/op/tensor_compute_op.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors - * \brief Compute Op. - * \file compute_op.cc + * \brief Tensor Compute Op. + * \file tensor_compute_op.cc */ #include #include @@ -29,7 +29,7 @@ int TensorComputeOpNode::num_outputs() const { } Array TensorComputeOpNode::root_iter_vars() const { - Array ret = axis; + Array ret = out_axis; for (IterVar iv : tensor_axis) { ret.push_back(iv); } @@ -45,7 +45,7 @@ Type TensorComputeOpNode::output_dtype(size_t i) const { Array TensorComputeOpNode::output_shape(size_t i) const { Array shape; - for (const auto& ivar : this->axis) { + for (const auto& ivar : this->out_axis) { shape.push_back(ivar->dom->extent); } size_t index = this->inputs.size() + i; @@ -58,12 +58,12 @@ Array TensorComputeOpNode::output_shape(size_t i) const { Operation TensorComputeOpNode::make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, TensorIntrinCall intrin_call) { return TensorComputeOpNode::make(name, tag, - axis, + out_axis, tensor_axis, intrin_call->reduce_axis, intrin_call->tensors, @@ -73,7 +73,7 @@ Operation TensorComputeOpNode::make(std::string name, Operation TensorComputeOpNode::make(std::string name, std::string tag, - Array axis, + Array out_axis, Array tensor_axis, Array reduce_axis, Array tensors, @@ -82,7 +82,15 @@ Operation TensorComputeOpNode::make(std::string name, auto n = make_node(); n->name = name; n->tag = tag; + Array axis; + for (auto iv : out_axis) { + axis.push_back(iv); + } + for (auto iv : tensor_axis) { + axis.push_back(iv); + } n->axis = axis; + n->out_axis = out_axis; n->tensor_axis = tensor_axis; n->reduce_axis = reduce_axis; n->inputs = tensors; @@ -148,10 +156,10 @@ void TensorComputeOpNode::GatherBound( const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const { const TensorDom& tdom = tensor_dom.at(self.output(0)); - for (size_t i = 0; i < this->axis.size(); ++i) { - Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); - CHECK(!out_dom_map->count(this->axis[i])); - (*out_dom_map)[this->axis[i]] = r; + for (size_t i = 0; i < this->out_axis.size(); ++i) { + Range r = arith::Union(tdom.data.at(i)).cover_range(this->out_axis[i]->dom); + CHECK(!out_dom_map->count(this->out_axis[i])); + (*out_dom_map)[this->out_axis[i]] = r; } for (size_t i = 0; i < this->reduce_axis.size(); ++i) { CHECK(!out_dom_map->count(this->reduce_axis[i])); @@ -165,7 +173,7 @@ Stmt TensorComputeOpNode::BuildRealize( const Stmt& body) const { CHECK_EQ(stage->op.get(), 
this); HalideIR::Internal::Region bounds; - for (IterVar iv : this->axis) { + for (IterVar iv : this->out_axis) { bounds.push_back(realize_map.at(iv)); } size_t out_buff_idx = this->intrin->buffers.size(); @@ -178,8 +186,8 @@ Stmt TensorComputeOpNode::BuildRealize( realize = ir::Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute - for (size_t i = 0; i < this->axis.size(); ++i) { - auto it = stage->iter_var_attrs.find(this->axis[i]); + for (size_t i = 0; i < this->out_axis.size(); ++i) { + auto it = stage->iter_var_attrs.find(this->out_axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { @@ -224,7 +232,7 @@ ComputeLoopNest MakeLoopNest( for (IterVar iv : self->reduce_axis) { update_state[iv] = 2; } - for (IterVar iv : self->axis) { + for (IterVar iv : self->out_axis) { update_state[iv] = 1; } // find which iter var is related to reduction and which is related to axis. @@ -264,21 +272,23 @@ ComputeLoopNest MakeLoopNest( } -Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +Stmt TensorComputeOpNode::BuildProvide( + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { + CHECK_EQ(stage->op.operator->(), this); + // Start bind data. Stmt nop = Evaluate::make(0); std::vector input_bind_nest, output_bind_nest; - Array inputs = self->InputTensors(); + Array inputs = this->InputTensors(); // input binding size_t num_inputs = inputs.size(); for (size_t i = 0; i < num_inputs; ++i) { Tensor tensor = inputs[i]; - Region region = self->input_regions[i]; - Buffer buffer = self->intrin->buffers[i]; + Region region = this->input_regions[i]; + Buffer buffer = this->intrin->buffers[i]; Array bind_spec{buffer, tensor}; Array tuple; @@ -292,13 +302,13 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, } // output binding - for (int i = 0; i < self->num_outputs(); ++i) { + for (int i = 0; i < this->num_outputs(); ++i) { Tensor tensor = stage->op.output(i); - Buffer buffer = self->intrin->buffers[num_inputs + i]; + Buffer buffer = this->intrin->buffers[num_inputs + i]; Array bind_spec{buffer, tensor}; Array tuple; - for (const IterVar ivar : self->axis) { + for (const IterVar ivar : this->out_axis) { tuple.push_back(ivar->var); tuple.push_back(1); } @@ -317,16 +327,16 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, ir::ArgBinder binder(&vmap); size_t tloc = stage->leaf_iter_vars.size(); - ComputeLoopNest n = MakeLoopNest(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop); - if (self->reduce_axis.size() == 0) { + if (this->reduce_axis.size() == 0) { std::vector > nest( n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(op::MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); - CHECK(self->intrin->body.defined()) - << "Normal store op for intrin " << self << " is not defined"; - Stmt body = MergeNest(output_bind_nest, self->intrin->body); + CHECK(this->intrin->body.defined()) + << "Normal store op for intrin " << this << " is not defined"; + Stmt body = MergeNest(output_bind_nest, this->intrin->body); body = MergeNest(input_bind_nest, body); body = ir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); @@ -335,26 +345,26 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* 
self, return ret; } else { // Need to split reduction - CHECK(self->intrin->reduce_update.defined()) + CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; // Need init and update steps - CHECK_NE(self->reduce_axis.size(), 0U); + CHECK_NE(this->reduce_axis.size(), 0U); std::vector > common( n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); std::vector > update_nest( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - if (self->intrin->reduce_init.defined()) { + if (this->intrin->reduce_init.defined()) { // init nest std::vector > init_nest( n.init_nest.begin(), n.init_nest.begin() + tloc + 1); init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - Stmt init = MergeNest(output_bind_nest, self->intrin->reduce_init); + Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); init = op::Substitute(init, n.init_vmap); init = MergeNest(init_nest, init); // The update - Stmt update = MergeNest(output_bind_nest, self->intrin->reduce_update); + Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); update = MergeNest(input_bind_nest, update); update = ir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); @@ -363,11 +373,11 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, return MergeNest(common, Block::make(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(self->intrin->body.defined()) + CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; Stmt update = TransformUpdate(stage, dom_map, n, - self->intrin->body, - self->intrin->reduce_update); + this->intrin->body, + this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = ir::Substitute(update, vmap); @@ -379,14 +389,4 @@ Stmt MakeTensorComputeOp(const TensorComputeOpNode* self, } } - -Stmt TensorComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { - CHECK_EQ(stage->op.operator->(), this); - Stmt ret = MakeTensorComputeOp(this, stage, dom_map, debug_keep_trivial_loop); - return ret; -} - } // namespace tvm diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 02704fb56124..20c2228b780d 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -135,30 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } - -// for tensor_op -Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, - const Array& tensor_array, - const std::string& scope) { - size_t tensor_size = tensor_array.size(); - sch->InvalidateCache(); - Tensor tensor = tensor_array[0]; - Stage orig_stage = sch[tensor->op]; - const TensorComputeOpNode* tensor_op = orig_stage->op.as(); - std::unordered_set red_axis; - for (IterVar iv : tensor_op->reduce_axis) { +template +void PrepareAxisMapping(Stage orig_stage, + OpType* op, + std::unordered_set* p_red_axis, + Array* p_new_axis, + std::unordered_map* p_dom_map, + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, + std::vector* p_predicates) { + auto& red_axis = *p_red_axis; + auto& new_axis = *p_new_axis; + auto& dom_map = *p_dom_map; + auto& vsub = *p_vsub; + auto& vsub2newvar = *p_vsub2newvar; + auto& predicates = *p_predicates; + + for (IterVar iv : op->reduce_axis) { 
red_axis.insert(iv); } - std::unordered_map dom_map; - Array new_axis; - - for (IterVar iv : tensor_op->root_iter_vars()) { + for (IterVar iv : op->axis) { dom_map[iv] = iv->dom; } schedule::PassDownDomain(orig_stage, &dom_map, true); - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; { // The source->cache std::unordered_map value_map; @@ -171,7 +170,7 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { - value_map[iv] = dom->extent; + value_map[iv] = dom->min; } else { value_map[iv] = iv->var; vsub2newvar[iv->var.get()] = new_iv->var; @@ -179,75 +178,32 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, } // skip reduction iteration. std::unordered_set skip_bound_check; - for (IterVar iv : tensor_op->reduce_axis) { + for (IterVar iv : op->reduce_axis) { skip_bound_check.insert(iv); } schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); predicates = schedule::MakeBoundCheck( orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis - for (IterVar iv : tensor_op->axis) { - vsub[iv->var.get()] = value_map.at(iv); - } - } - - Array new_regions; - for (Region old_region : tensor_op->input_regions) { - Region region; - for (Range r : old_region) { - Expr min = VarReplacer(vsub2newvar).Mutate(r->min); - Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); - region.push_back(Range::make_by_min_extent(min, extent)); - } - new_regions.push_back(region); - } - - Operation cache_op = TensorComputeOpNode::make( - tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->tensor_axis, - tensor_op->inputs, new_regions, tensor_op->intrin); - - // axis will be used in generating compute op - Array compute_axis = tensor_op->axis; - for (IterVar iv : tensor_op->tensor_axis) { - // new tensor axis with kDataPar IterVar type - IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); - compute_axis.push_back(aiv); - } - - // The reader args - Array args; - { - // cache->compute - std::unordered_map value_map; - for (IterVar iv : compute_axis) { - value_map[iv] = iv->var; - } - schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); - for (IterVar iv : orig_stage->leaf_iter_vars) { - if (red_axis.count(iv)) continue; - args.push_back(value_map.at(iv)); - } - // tensorized region axis - for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { - IterVar iv = compute_axis[tensor_op->axis.size() + i]; - args.push_back(value_map.at(iv)); + for (IterVar iv : op->axis) { + if (value_map.count(iv)) { + vsub[iv->var.get()] = value_map.at(iv); + } // to handle tensor axis } } +} +Array ReplaceOriginalOp(Schedule sch, + Stage orig_stage, + const std::string& scope, + Operation cache_op, + Operation orig_new_op, + size_t tensor_size) { Array cache_tensor_list; - Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_tensor_list.push_back(cache_tensor); - cache_expr_list.push_back(cache_tensor(args)); } - - Operation orig_new_op = ComputeOpNode::make( - tensor_op->name, tensor_op->tag, {}, - compute_axis, cache_expr_list); - - // The replace of the dataflow std::unordered_map vmap; std::unordered_map rvmap; @@ -283,61 +239,24 @@ Array CacheWriteWithReLayoutForTensorComputeOp(Schedule sch, // Cache write and relayout the data according to loop pattern Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, - 
const std::string& scope) { + const Array& tensor_array, + const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); Tensor tensor = tensor_array[0]; Stage orig_stage = sch[tensor->op]; - if (!strcmp(orig_stage->op->type_key(), "TensorComputeOp")) { - return CacheWriteWithReLayoutForTensorComputeOp(sch, tensor_array, scope); - } const ComputeOpNode* compute = orig_stage->op.as(); + std::unordered_set red_axis; - for (IterVar iv : compute->reduce_axis) { - red_axis.insert(iv); - } - std::unordered_map dom_map; Array new_axis; + std::unordered_map dom_map; - for (IterVar iv : compute->axis) { - dom_map[iv] = iv->dom; - } - schedule::PassDownDomain(orig_stage, &dom_map, true); std::unordered_map vsub; std::unordered_map vsub2newvar; std::vector predicates; - { - // The source->cache - std::unordered_map value_map; - for (IterVar iv : orig_stage->leaf_iter_vars) { - if (red_axis.count(iv)) continue; - CHECK_EQ(iv->iter_type, kDataPar) - << "Can only relayout with in data parallel dimensions"; - Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make( - dom, iv->var.copy_with_suffix(".c"), iv->iter_type); - new_axis.push_back(new_iv); - if (is_one(dom->min)) { - value_map[iv] = dom->min; - } else { - value_map[iv] = iv->var; - vsub2newvar[iv->var.get()] = new_iv->var; - } - } - // skip reduction iteration. - std::unordered_set skip_bound_check; - for (IterVar iv : compute->reduce_axis) { - skip_bound_check.insert(iv); - } - schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); - predicates = schedule::MakeBoundCheck( - orig_stage, dom_map, value_map, true, skip_bound_check); - // The root axis - for (IterVar iv : compute->axis) { - vsub[iv->var.get()] = value_map.at(iv); - } - } + + PrepareAxisMapping(orig_stage, compute, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); Expr body; Array body_list; @@ -383,46 +302,98 @@ Array CacheWriteWithReLayout(Schedule sch, Operation cache_op = ComputeOpNode::make( compute->name + "." 
+ scope, compute->tag, compute->attrs, new_axis, body_list); - Array cache_tensor_list; + Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); - cache_tensor_list.push_back(cache_tensor); cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = ComputeOpNode::make( compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); - // The replace of the dataflow - std::unordered_map vmap; - std::unordered_map rvmap; - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); - for (size_t i = 0; i < tensor_size; i++) { - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); +} + + +// for tensor compute op +Array CacheWriteWithReLayoutTensor(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const TensorComputeOpNode* tensor_op = orig_stage->op.as(); + CHECK_EQ(tensor_op->num_outputs(), 1) + << "cache write only support single output tensor_compute_op"; + + std::unordered_set red_axis; + Array new_axis; + std::unordered_map dom_map; + + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + + PrepareAxisMapping(orig_stage, tensor_op, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + + + Array new_regions; + for (Region old_region : tensor_op->input_regions) { + Region region; + for (Range r : old_region) { + Expr min = VarReplacer(vsub2newvar).Mutate(r->min); + Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + region.push_back(Range::make_by_min_extent(min, extent)); + } + new_regions.push_back(region); } - ReplaceDataFlow(sch->stages, &vmap, &rvmap); - // mutate orig stage - orig_stage->op = orig_new_op; - orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); - orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; - orig_stage->relations = Array(); - // create schedule for new cached stage. - ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage.node_); - sch->stage_map.Set(cache_op, cache_stage); - // Update group - cache_stage->group = orig_stage->group; - if (cache_stage->group.defined()) { - ++cache_stage->group->num_child_stages; + + Operation cache_op = TensorComputeOpNode::make( + tensor_op->name + "." 
+ scope, tensor_op->tag, new_axis, + tensor_op->tensor_axis, tensor_op->reduce_axis, + tensor_op->inputs, new_regions, tensor_op->intrin); + + // axis will be used in generating compute op + Array compute_axis = tensor_op->out_axis; + for (IterVar iv : tensor_op->tensor_axis) { + // new tensor axis with kDataPar IterVar type + IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + compute_axis.push_back(aiv); } - return cache_tensor_list; + + // The reader args + Array args; + { + // cache->compute + std::unordered_map value_map; + for (IterVar iv : compute_axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } + // tensorized region axis + for (size_t i = 0; i < tensor_op->tensor_axis.size(); ++i) { + IterVar iv = compute_axis[tensor_op->out_axis.size() + i]; + args.push_back(value_map.at(iv)); + } + } + + Array cache_expr_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_expr_list.push_back(cache_tensor(args)); + } + Operation orig_new_op = ComputeOpNode::make( + tensor_op->name, tensor_op->tag, {}, + compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); } @@ -441,23 +412,26 @@ Array Schedule::cache_write(const Array& tensor_array, CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } - return CacheWriteWithReLayout(*this, tensor_array, scope); } + Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { + // support original compute and tensor compute both (*this)->InvalidateCache(); - Stage orig_stage = operator[](tensor->op); - // const ComputeOpNode* compute = tensor->op.as(); - // CHECK(compute) - // << "cache write only take ComputeOp as writers"; - // CHECK_EQ(compute->num_outputs(), 1) - // << "cache write only support single output ComputeOp"; - - return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + const char* type_key = tensor->op->type_key(); + if (!strcmp(type_key, "ComputeOp")) { + return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + } else if (!strcmp(type_key, "TensorComputeOp")) { + return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0]; + } else { + LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers"; + return Tensor(); + } } + void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 8e6f4090d403..ef1babbec72d 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -276,6 +276,135 @@ def test_schedule_bound_condition(): stmt = tvm.ir_pass.Simplify(stmt) assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) + +def intrin_gemv(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body 
= tvm.call_packed( + "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + reset = tvm.call_packed( + "fill_zero", zz_ptr, n) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, reset, update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + + +def test_schedule_tensor_compute1(): + # basic: split, reorder, tile + M, N, L = 2048, 1024, 512 + factor, rfactor = 16, 16 + A = tvm.placeholder((N/factor, L/rfactor, factor, rfactor), name='A') + B = tvm.placeholder((M, L/rfactor, rfactor), name='B') + k = tvm.reduce_axis((0, L/rfactor), name='k') + + gemv = intrin_gemv(factor, rfactor) + C = tvm.compute((N, M/factor, factor), + lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k), + name='C') + + s = tvm.create_schedule(C.op) + ai, aj, ax = s[C].op.axis + aio, aii = s[C].split(ai, 16) + s[C].reorder(aio, aj, aii) + aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4) + + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def intrin_vadd(n, cache_read=False, cache_write=False): + scope_ubuf = 'local' + dtype = 'float32' + x = tvm.placeholder((n,), dtype=dtype, name='vx') + y = tvm.placeholder((n,), dtype=dtype, name='vy') + z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z') + s = tvm.create_schedule(z.op) + + def create_buffer(t): + return tvm.decl_buffer(t.shape, t.dtype, + name='W'+t.name, + scope=scope_ubuf, + offset_factor=16) + + binds = {} + if cache_read: + binds[x] = create_buffer(x) + binds[y] = create_buffer(y) + if cache_write: + binds[z] = create_buffer(z) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) + return ib.get() + + with tvm.build_config(offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds) + + +def test_schedule_tensor_compute2(): + # cache_read, cache_write + M = 1024 + factor = 16 + dtype = 'float32' + scope_ubuf = 'local' + + A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M/factor, factor), name="B", dtype=dtype) + + vadd = intrin_vadd(factor, True, True) + C = tvm.compute((M/factor, factor), + lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C') + + s = tvm.create_schedule(C.op) + AL = s.cache_read(A, scope_ubuf, C) + BL = s.cache_read(B, scope_ubuf, C) + CL = s.cache_write(C, scope_ubuf) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def test_schedule_tensor_compute3(): + M = 1024 + factor = 16 + dtype = 'float32' + A = tvm.placeholder((M/factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M/factor, factor), name="B", dtype=dtype) + Bi = tvm.compute((M/factor, factor), lambda i, j: B[i, j] + 5, name="Bi") + + vadd = intrin_vadd(factor) + C = tvm.compute((M/factor, factor), + lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C') + s = tvm.create_schedule(C.op) + s[Bi].compute_at(s[C], C.op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + if __name__ == "__main__": test_schedule_middle_cache() test_inline_multi_reduce() @@ -294,3 +423,6 @@ def 
test_schedule_bound_condition():
     test_schedule2()
     test_schedule_cache()
     test_schedule_bound_condition()
+    test_schedule_tensor_compute1()
+    test_schedule_tensor_compute2()
+    test_schedule_tensor_compute3()
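
The snippet below is a minimal end-to-end sketch condensed from the new tests in this patch (intrin_vadd and test_schedule_tensor_compute2/3); vadd_intrin is a hypothetical rename of the tests' helper, and the sketch assumes the same public tvm Python API the tests exercise. It shows the flow this change enables: calling a TensorIntrin inside tvm.compute yields a TensorComputeOp, and cache_write on the resulting tensor now dispatches to the new TensorComputeOp path.

import tvm

def vadd_intrin(n):
    # Hypothetical helper mirroring intrin_vadd() from the new tests.
    x = tvm.placeholder((n,), name="vx")
    y = tvm.placeholder((n,), name="vy")
    z = tvm.compute((n,), lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, "vadd",
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr("r"),
                                outs[0].access_ptr("w")))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

M, factor = 1024, 16
A = tvm.placeholder((M // factor, factor), name="A")
B = tvm.placeholder((M // factor, factor), name="B")

vadd = vadd_intrin(factor)
# Calling the intrinsic inside tvm.compute produces a TensorComputeOp.
C = tvm.compute((M // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name="C")

s = tvm.create_schedule(C.op)
CL = s.cache_write(C, "local")  # dispatches to the TensorComputeOp cache-write path
print(tvm.lower(s, [A, B, C], simple_mode=True))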