Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend TensorComputeOp to allow scalar inputs (#2606). #3300

Merged
merged 1 commit into from
Jun 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
Expand Down Expand Up @@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name,
std::string tag,
Expand All @@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);
Array<Region> regions,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
Expand Down
16 changes: 15 additions & 1 deletion include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class TensorIntrinNode : public Node {
* When it is a constant, it means we can only take data in that shape.
*/
Array<Buffer> buffers;
/*! \brief List of scalar variables, used in body. These placeholders
* will be bound to expressions passed in when the TensorIntrin is called
* from a TensorComputeOp.
*/
Array<Var> scalar_params;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
Expand All @@ -87,6 +92,7 @@ class TensorIntrinNode : public Node {
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
v->Visit("scalar_params", &scalar_params);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
Expand All @@ -96,6 +102,7 @@ class TensorIntrinNode : public Node {
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
Expand Down Expand Up @@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node {
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;


/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;

/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("scalar_inputs", &scalar_inputs);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);
Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
out_ndim,
body.intrin,
body.tensors,
body.regions)
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
Expand Down
22 changes: 16 additions & 6 deletions python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
decl_tensor_intrin: Construct a TensorIntrin
"""
def __call__(self, *args, **kwargs):
    """Invoke the tensor intrinsic on a mix of tensor-slice and scalar arguments.

    Parameters
    ----------
    args : positional arguments
        ``TensorSlice`` arguments become tensor inputs (each contributing its
        access region); any other argument is treated as a scalar expression
        input to be bound to the intrinsic's ``scalar_params``.
    kwargs : keyword arguments
        Only ``reduce_axis`` is recognized: an IterVar, or a list/tuple of
        IterVars, naming the reduction axes the intrinsic will use.

    Returns
    -------
    call : TensorIntrinCall
        The node describing this invocation.
    """
    # Partition positional args by kind; relative order is preserved within
    # each category so tensors line up with their regions and scalars with
    # the intrinsic's scalar placeholders.
    tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
    scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
    regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
    reduce_axis = []
    if "reduce_axis" in kwargs:
        reduce_axis = kwargs["reduce_axis"]
        if not isinstance(reduce_axis, (list, tuple)):
            reduce_axis = [reduce_axis]
        reduce_axis = _api.convert(reduce_axis)
    if scalar_inputs:
        scalar_inputs = _api.convert(scalar_inputs)
    return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)

def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None):
binds=None, scalar_params=None):
"""Declare a tensor intrinsic function.

Parameters
Expand Down Expand Up @@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

scalar_params: a list of variables used by op, whose values will be passed
as scalar_inputs when the tensor intrinsic is called.

Returns
-------
intrin: TensorIntrin
Expand All @@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
offset_factor=cfg.offset_factor))
binds_list.append(buf)

body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
if scalar_params:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
name, op, inputs, binds_list, *body)
name, op, inputs, binds_list, scalar_params, *body)
6 changes: 5 additions & 1 deletion src/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
Expand All @@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
n->scalar_params = std::move(scalar_params);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
Expand All @@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
// Factory for TensorIntrinCallNode: records one invocation of a tensor
// intrinsic together with its tensor inputs (and the regions of those
// tensors it reads), its reduction axes, and the scalar expression inputs
// that will be bound to the intrinsic's scalar_params.
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                            Array<Tensor> tensors,
                                            Array<Region> regions,
                                            Array<IterVar> reduce_axis,
                                            Array<Expr> scalar_inputs) {
  auto n = make_node<TensorIntrinCallNode>();
  n->intrin = std::move(intrin);
  n->tensors = std::move(tensors);
  n->regions = std::move(regions);
  n->reduce_axis = std::move(reduce_axis);
  n->scalar_inputs = std::move(scalar_inputs);
  return TensorIntrinCall(n);
}

Expand Down
17 changes: 16 additions & 1 deletion src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions) {
Array<Region> regions,
Array<Expr> scalar_inputs) {
auto n = make_node<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
Expand All @@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name,
n->intrin = std::move(intrin);
n->inputs = std::move(tensors);
n->input_regions = std::move(regions);
n->scalar_inputs = std::move(scalar_inputs);
return Operation(n);
}

Expand Down Expand Up @@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide(
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);

// Map the expressions passed in the call to the TensorIntrin, to the placeholder
// variables
Array<Expr> user_expr = this->scalar_inputs;
Array<Var> scalar_params = this->intrin->scalar_params;
Array<Expr> sp_expr;
for (auto sp : scalar_params) {
Expr esp = sp;
sp_expr.push_back(esp);
}
CHECK_EQ(sp_expr.size(), user_expr.size());
// TODO(jdavies-huawei): what name should be used here?
binder.BindArray(sp_expr, user_expr, this->name);

size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);

Expand Down
7 changes: 6 additions & 1 deletion src/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
new_regions.push_back(region);
}

Array<Expr> new_scalar_inputs;
for (Expr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
}

Operation cache_op = TensorComputeOpNode::make(
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
tensor_op->intrin, tensor_op->inputs, new_regions);
tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);

// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,43 @@ def intrin_func(ins, outs):
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)

def test_tensor_intrin_scalar_params():
    # Checks that a TensorIntrin declared with scalar_params receives the
    # scalar placeholder vars in its fcompute, and that a TensorComputeOp
    # call may interleave scalar expressions with tensor-slice arguments.
    n = tvm.var("n")
    x = tvm.placeholder((n,), name='x')
    v = tvm.var("v")
    w = tvm.var("w")
    # z uses the two free scalar vars v and w in its body.
    z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')

    def intrin_func(ins, outs, sp):
        # ins/outs are bound Buffers; sp holds the scalar placeholders
        # in the order given by scalar_params=[v, w].
        assert(isinstance(ins[0], tvm.schedule.Buffer))
        assert(ins[0].shape[0] == n)
        assert(sp[0] == v)
        assert(sp[1] == w)
        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

    # NOTE(review): the scrape lost indentation; nesting of the statements
    # below under build_config is reconstructed — confirm against upstream.
    with tvm.build_config(offset_factor=1):
        intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
        assert intrin.op == z.op
        assert intrin.reduce_init is None
        assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
        assert(intrin.buffers[0].shape[0] == n)
        # scalar_params are exposed on the declared intrinsic.
        assert tuple(intrin.scalar_params) == tuple((v, w))

        A = tvm.placeholder((10,10), name='A')
        # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
        C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
        s = tvm.create_schedule(C.op)
        stmt = tvm.lower(s, [A, C], simple_mode=True)
        # The lowered body is the packed call; args are
        # (name, in_ptr, out_ptr, scalar0, scalar1), so the two scalar
        # expressions land at positions 3 and 4.
        assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
        assert len(stmt.body.body.body.value.args) == 5
        assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
        assert str(stmt.body.body.body.value.args[4]) == "(i + j)"

if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_tensor_intrin_scalar_params()
test_rfactor()
test_schedule_create()
test_reorder()
Expand Down