Skip to content

Commit

Permalink
Extend TensorComputeOp to allow scalar inputs (#2606).
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavies-huawei committed Jun 6, 2019
1 parent d7bc4fd commit def8125
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 12 deletions.
6 changes: 5 additions & 1 deletion include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
Expand Down Expand Up @@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name,
std::string tag,
Expand All @@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);
Array<Region> regions,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
Expand Down
16 changes: 15 additions & 1 deletion include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class TensorIntrinNode : public Node {
* When it is a constant, it means we can only take data in that shape.
*/
Array<Buffer> buffers;
/*! \brief List of scalar variables, used in body. These placeholders
* will be bound to expressions passed in when the TensorIntrin is called
* from a TensorComputeOp.
*/
Array<Var> scalar_params;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
Expand All @@ -87,6 +92,7 @@ class TensorIntrinNode : public Node {
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
v->Visit("scalar_params", &scalar_params);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
Expand All @@ -96,6 +102,7 @@ class TensorIntrinNode : public Node {
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
Expand Down Expand Up @@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node {
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;


/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;

/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("scalar_inputs", &scalar_inputs);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);
Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
out_ndim,
body.intrin,
body.tensors,
body.regions)
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
Expand Down
22 changes: 16 additions & 6 deletions python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
decl_tensor_intrin: Construct a TensorIntrin
"""
def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_get_region(x) for x in args]
tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
if scalar_inputs:
scalar_inputs = _api.convert(scalar_inputs)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)

def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None):
binds=None, scalar_params=None):
"""Declare a tensor intrinsic function.
Parameters
Expand Down Expand Up @@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
scalar_params: a list of variables used by op, whose values will be passed
as scalar_inputs when the tensor intrinsic is called.
Returns
-------
intrin: TensorIntrin
Expand All @@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
offset_factor=cfg.offset_factor))
binds_list.append(buf)

body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
if scalar_params:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
name, op, inputs, binds_list, *body)
name, op, inputs, binds_list, scalar_params, *body)
6 changes: 5 additions & 1 deletion src/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
Expand All @@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
n->scalar_params = std::move(scalar_params);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
Expand All @@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis) {
Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs) {
auto n = make_node<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
n->regions = std::move(regions);
n->reduce_axis = std::move(reduce_axis);
n->scalar_inputs = std::move(scalar_inputs);
return TensorIntrinCall(n);
}

Expand Down
17 changes: 16 additions & 1 deletion src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions) {
Array<Region> regions,
Array<Expr> scalar_inputs) {
auto n = make_node<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
Expand All @@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name,
n->intrin = std::move(intrin);
n->inputs = std::move(tensors);
n->input_regions = std::move(regions);
n->scalar_inputs = std::move(scalar_inputs);
return Operation(n);
}

Expand Down Expand Up @@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide(
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);

// Map the expressions passed in the call to the TensorIntrin, to the placeholder
// variables
Array<Expr> user_expr = this->scalar_inputs;
Array<Var> scalar_params = this->intrin->scalar_params;
Array<Expr> sp_expr;
for (auto sp : scalar_params) {
Expr esp = sp;
sp_expr.push_back(esp);
}
CHECK_EQ(sp_expr.size(), user_expr.size());
// TODO(jdavies-huawei): what name should be used here?
binder.BindArray(sp_expr, user_expr, this->name);

size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);

Expand Down
7 changes: 6 additions & 1 deletion src/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
new_regions.push_back(region);
}

Array<Expr> new_scalar_inputs;
for (Expr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
}

Operation cache_op = TensorComputeOpNode::make(
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
tensor_op->intrin, tensor_op->inputs, new_regions);
tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);

// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,43 @@ def intrin_func(ins, outs):
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)

def test_tensor_intrin_scalar_params():
n = tvm.var("n")
x = tvm.placeholder((n,), name='x')
v = tvm.var("v")
w = tvm.var("w")
z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')

def intrin_func(ins, outs, sp):
assert(isinstance(ins[0], tvm.schedule.Buffer))
assert(ins[0].shape[0] == n)
assert(sp[0] == v)
assert(sp[1] == w)
return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

with tvm.build_config(offset_factor=1):
intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
assert intrin.op == z.op
assert intrin.reduce_init is None
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
assert(intrin.buffers[0].shape[0] == n)
assert tuple(intrin.scalar_params) == tuple((v, w))

A = tvm.placeholder((10,10), name='A')
# Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True)
assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
assert len(stmt.body.body.body.value.args) == 5
assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
assert str(stmt.body.body.body.value.args[4]) == "(i + j)"

if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_tensor_intrin_scalar_params()
test_rfactor()
test_schedule_create()
test_reorder()
Expand Down

0 comments on commit def8125

Please sign in to comment.