From def81252be3feee13af529e6f122ec9ae906e84d Mon Sep 17 00:00:00 2001
From: Jessica Davies
Date: Fri, 15 Feb 2019 17:53:37 -0500
Subject: [PATCH] Extend TensorComputeOp to allow scalar inputs (#2606).

---
 include/tvm/operation.h                     |  6 +++-
 include/tvm/tensor_intrin.h                 | 16 ++++++++++-
 python/tvm/api.py                           |  3 +-
 python/tvm/tensor_intrin.py                 | 22 ++++++++++----
 src/lang/tensor.cc                          |  6 +++-
 src/op/tensor_compute_op.cc                 | 17 ++++++++++-
 src/schedule/schedule_dataflow_rewrite.cc   |  7 ++++-
 tests/python/unittest/test_lang_schedule.py | 32 +++++++++++++++++++++
 8 files changed, 97 insertions(+), 12 deletions(-)

diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 15a8c1215177..38dc39bbe7a7 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
   Array<Tensor> inputs;
   /*! \brief region of input tensors */
   Array<Region> input_regions;
+  /*! \brief scalar expression inputs */
+  Array<Expr> scalar_inputs;
   /*! \brief constructor */
   TensorComputeOpNode() {}
   // override functions
@@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
     v->Visit("intrin", &intrin);
     v->Visit("inputs", &inputs);
     v->Visit("input_regions", &input_regions);
+    v->Visit("scalar_inputs", &scalar_inputs);
   }
   static Operation make(std::string name,
                         std::string tag,
@@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
                         int schedulable_ndim,
                         TensorIntrin intrin,
                         Array<Tensor> tensors,
-                        Array<Region> regions);
+                        Array<Region> regions,
+                        Array<Expr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h
index e61ce6634bd3..b5ca6eb4358b 100644
--- a/include/tvm/tensor_intrin.h
+++ b/include/tvm/tensor_intrin.h
@@ -67,6 +67,11 @@ class TensorIntrinNode : public Node {
    * When it is a constant, it means we can only take data in that shape.
    */
   Array<Buffer> buffers;
+  /*! \brief List of scalar variables, used in body. These placeholders
+   * will be bound to expressions passed in when the TensorIntrin is called
+   * from a TensorComputeOp.
+   */
+  Array<Var> scalar_params;
   /*! \brief The normal statement to execute the intrinsic */
   Stmt body;
   /*!
@@ -87,6 +92,7 @@ class TensorIntrinNode : public Node {
     v->Visit("op", &op);
     v->Visit("inputs", &inputs);
     v->Visit("buffers", &buffers);
+    v->Visit("scalar_params", &scalar_params);
     v->Visit("body", &body);
     v->Visit("reduce_init", &reduce_init);
     v->Visit("reduce_update", &reduce_update);
@@ -96,6 +102,7 @@ class TensorIntrinNode : public Node {
                            Operation op,
                            Array<Tensor> inputs,
                            Array<Buffer> buffers,
+                           Array<Var> scalar_params,
                            Stmt body,
                            Stmt reduce_init,
                            Stmt reduce_update);
@@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node {
   Array<Tensor> tensors;
   /*! \brief regions of input tensors */
   Array<Region> regions;
+
+
   /*!
    * \brief IterVar on each reduction axis, if the
    * intrin will use the reduce axis
    */
   Array<IterVar> reduce_axis;
+  /*! \brief scalar expression inputs */
+  Array<Expr> scalar_inputs;
+
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("intrin", &intrin);
     v->Visit("tensors", &tensors);
     v->Visit("regions", &regions);
     v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("scalar_inputs", &scalar_inputs);
   }
   static TensorIntrinCall make(TensorIntrin intrin,
                                Array<Tensor> tensors,
                                Array<Region> regions,
-                               Array<IterVar> reduce_axis);
+                               Array<IterVar> reduce_axis,
+                               Array<Expr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorIntrinCall";
   TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 66fa4fa30e90..d88f06170543 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
                                                  out_ndim,
                                                  body.intrin,
                                                  body.tensors,
-                                                 body.regions)
+                                                 body.regions,
+                                                 body.scalar_inputs)
     else:
         if not isinstance(body, (list, tuple)):
             body = [body]
diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py
index f97e6b7579a1..2ef7a4bbb23e 100644
--- a/python/tvm/tensor_intrin.py
+++ b/python/tvm/tensor_intrin.py
@@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
     decl_tensor_intrin: Construct a TensorIntrin
     """
     def __call__(self, *args, **kwargs):
-        tensors = [x.tensor for x in args]
-        regions = [_get_region(x) for x in args]
+        tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
+        scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
+        regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
         reduce_axis = []
         if "reduce_axis" in kwargs:
             reduce_axis = kwargs["reduce_axis"]
             if not isinstance(reduce_axis, (list, tuple)):
                 reduce_axis = [reduce_axis]
             reduce_axis = _api.convert(reduce_axis)
-        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
+        if scalar_inputs:
+            scalar_inputs = _api.convert(scalar_inputs)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
 
 def decl_tensor_intrin(op,
                        fcompute,
                        name="tensor_intrin",
-                       binds=None):
+                       binds=None, scalar_params=None):
     """Declare a tensor intrinsic function.
 
     Parameters
@@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
 
+    scalar_params: a list of variables used by op, whose values will be passed
+                   as scalar_inputs when the tensor intrinsic is called.
+
     Returns
     -------
     intrin: TensorIntrin
@@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
                 offset_factor=cfg.offset_factor))
         binds_list.append(buf)
 
-    body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+    if scalar_params:
+        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
+    else:
+        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+        scalar_params = []
     if isinstance(body, (_expr.Expr, _stmt.Stmt)):
         body = [body]
     body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
     return _api_internal._TensorIntrin(
-        name, op, inputs, binds_list, *body)
+        name, op, inputs, binds_list, scalar_params, *body)
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index bab7cf6d93ed..d885d7103606 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
                                     Array<Tensor> inputs,
                                     Array<Buffer> buffers,
+                                    Array<Var> scalar_params,
                                     Stmt body,
                                     Stmt reduce_init,
                                     Stmt reduce_update) {
@@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
   n->op = std::move(op);
   n->inputs = std::move(inputs);
   n->buffers = std::move(buffers);
+  n->scalar_params = std::move(scalar_params);
   n->body = std::move(body);
   n->reduce_init = std::move(reduce_init);
   n->reduce_update = std::move(reduce_update);
@@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
 TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                             Array<Tensor> tensors,
                                             Array<Region> regions,
-                                            Array<IterVar> reduce_axis) {
+                                            Array<IterVar> reduce_axis,
+                                            Array<Expr> scalar_inputs) {
   auto n = make_node<TensorIntrinCallNode>();
   n->intrin = std::move(intrin);
   n->tensors = std::move(tensors);
   n->regions = std::move(regions);
   n->reduce_axis = std::move(reduce_axis);
+  n->scalar_inputs = std::move(scalar_inputs);
   return TensorIntrinCall(n);
 }
 
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
index ed768c2ba216..09e8af7d5cba 100644
--- a/src/op/tensor_compute_op.cc
+++ b/src/op/tensor_compute_op.cc
@@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name,
                                     int schedulable_ndim,
                                     TensorIntrin intrin,
                                     Array<Tensor> tensors,
-                                    Array<Region> regions) {
+                                    Array<Region> regions,
+                                    Array<Expr> scalar_inputs) {
   auto n = make_node<TensorComputeOpNode>();
   n->name = std::move(name);
   n->tag = std::move(tag);
@@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name,
   n->intrin = std::move(intrin);
   n->inputs = std::move(tensors);
   n->input_regions = std::move(regions);
+  n->scalar_inputs = std::move(scalar_inputs);
   return Operation(n);
 }
 
@@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide(
   std::unordered_map<const Variable*, Expr> vmap;
   ir::ArgBinder binder(&vmap);
 
+  // Map the expressions passed in the call to the TensorIntrin
+  // to the corresponding placeholder variables.
+  Array<Expr> user_expr = this->scalar_inputs;
+  Array<Var> scalar_params = this->intrin->scalar_params;
+  Array<Expr> sp_expr;
+  for (auto sp : scalar_params) {
+    Expr esp = sp;
+    sp_expr.push_back(esp);
+  }
+  CHECK_EQ(sp_expr.size(), user_expr.size());
+  // TODO(jdavies-huawei): what name should be used here?
+  binder.BindArray(sp_expr, user_expr, this->name);
+
   size_t tloc = stage->leaf_iter_vars.size();
   ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
 
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index cbb5ae3df0d6..bb42754ac820 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -410,10 +410,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
     new_regions.push_back(region);
   }
 
+  Array<Expr> new_scalar_inputs;
+  for (Expr old_input : tensor_op->scalar_inputs) {
+    new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
+  }
+
   Operation cache_op = TensorComputeOpNode::make(
       tensor_op->name + "." + scope, tensor_op->tag, new_axis,
       tensor_op->reduce_axis, tensor_op->schedulable_ndim,
-      tensor_op->intrin, tensor_op->inputs, new_regions);
+      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
 
   // axis will be used in generating compute op
   Array<IterVar> compute_axis = tensor_op->axis;
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index fb33e2b6fc0a..627921546f3c 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -209,11 +209,43 @@ def intrin_func(ins, outs):
     assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
     assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
 
+def test_tensor_intrin_scalar_params():
+    n = tvm.var("n")
+    x = tvm.placeholder((n,), name='x')
+    v = tvm.var("v")
+    w = tvm.var("w")
+    z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')
+
+    def intrin_func(ins, outs, sp):
+        assert(isinstance(ins[0], tvm.schedule.Buffer))
+        assert(ins[0].shape[0] == n)
+        assert(sp[0] == v)
+        assert(sp[1] == w)
+        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
+
+    with tvm.build_config(offset_factor=1):
+        intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
+        assert intrin.op == z.op
+        assert intrin.reduce_init is None
+        assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
+        assert(intrin.buffers[0].shape[0] == n)
+        assert tuple(intrin.scalar_params) == tuple((v, w))
+
+    A = tvm.placeholder((10,10), name='A')
+    # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
+    C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
+    assert len(stmt.body.body.body.value.args) == 5
+    assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
+    assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
 
 if __name__ == "__main__":
     test_singleton()
     test_pragma()
     test_tensor_intrin()
+    test_tensor_intrin_scalar_params()
     test_rfactor()
     test_schedule_create()
     test_reorder()
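
Usage sketch (illustrative, not part of the diff above; it assumes the TVM 0.5-era Python API, and the packed-function name "my_hw_func" plus all shapes are made up for the example). A tensor intrinsic declared with scalar_params accepts scalar expressions interleaved with tensor slices when called from tvm.compute; at lowering time TensorComputeOpNode::BuildProvide binds those expressions to the intrinsic's placeholder variables via ArgBinder::BindArray.

    import tvm

    # Intrinsic pattern: scale a vector by a runtime scalar.
    n = tvm.var("n")
    x = tvm.placeholder((n,), name="x")
    scale = tvm.var("scale")
    z = tvm.compute((n,), lambda i: x[i] * scale, name="z")

    def intrin_func(ins, outs, sp):
        # ins/outs are Buffers; sp holds the bound scalar expressions.
        return tvm.call_packed("my_hw_func", ins[0].data, outs[0].data, sp[0])

    with tvm.build_config(offset_factor=1):
        vscale = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[scale])

    A = tvm.placeholder((8, 8), name="A")
    # The scalar expression (here i + j) is passed alongside the tensor slice.
    C = tvm.compute((8, 8), lambda i, j: vscale(A[i, j], i + j), name="C")
    s = tvm.create_schedule(C.op)
    print(tvm.lower(s, [A, C], simple_mode=True))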