Skip to content

Commit

Permalink
[TensorOp] Support for intrin(..) and rename to TensorComputeOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Sep 16, 2018
1 parent c2389b1 commit 1dfa7b4
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 183 deletions.
19 changes: 13 additions & 6 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
};

class TensorOpNode : public OperationNode {
class TensorComputeOpNode : public OperationNode {
public:
Array<IterVar> axis;

Expand All @@ -197,7 +197,7 @@ class TensorOpNode : public OperationNode {
TensorIntrin intrin;

/*! \brief constructor */
TensorOpNode() {}
TensorComputeOpNode() {}

// override functions
int num_outputs() const final;
Expand Down Expand Up @@ -237,13 +237,20 @@ class TensorOpNode : public OperationNode {
static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> tensor_axis,
TensorIntrinCall intrin_call);

static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> tensor_axis,
Array<IterVar> reduce_axis,
Array<Tensor> inputs,
Array<Region> input_regions,
Array<Tensor> tensors,
Array<Region> regions,
TensorIntrin intrin);

static constexpr const char* _type_key = "TensorOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorOpNode, OperationNode);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
};

/*!
Expand Down
55 changes: 55 additions & 0 deletions include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ class TensorIntrin : public NodeRef {
*/
inline const TensorIntrinNode* operator->() const;

// template<typename... Args>
// inline Stmt operator()(Args&& ...args) const {
// Array<Expr> inputs{std::forward<Args>(args)...};
// return operator()(inputs);
// }

// TVM_DLL TensorIntrinCall operator()(Array<Expr> inputs) const;

/*! \brief specify container node */
using ContainerType = TensorIntrinNode;
};
Expand Down Expand Up @@ -89,5 +97,52 @@ class TensorIntrinNode : public Node {
// Dereference helper: view the shared node storage as the concrete
// TensorIntrinNode container type.
inline const TensorIntrinNode* TensorIntrin::operator->() const {
  const Node* raw = node_.get();
  return static_cast<const TensorIntrinNode*>(raw);
}


// Internal node container of a tensor intrinsic call (defined below).
class TensorIntrinCallNode;

/*!
 * \brief Reference to the result of applying a tensor intrinsic to
 *  concrete tensor regions; see TensorIntrinCallNode for the fields.
 */
class TensorIntrinCall : public NodeRef {
 public:
  /*! \brief default constructor, creates a null reference */
  TensorIntrinCall() {}
  /*! \brief constructor from an existing node container */
  explicit TensorIntrinCall(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const TensorIntrinCallNode* operator->() const;

  /*! \brief specify container node */
  using ContainerType = TensorIntrinCallNode;
};

/*! \brief Node recording one invocation of a tensor intrinsic. */
class TensorIntrinCallNode : public Node {
 public:
  /*! \brief the tensor intrinsic being invoked */
  TensorIntrin intrin;
  /*! \brief input tensors of the call */
  Array<Tensor> tensors;
  /*! \brief accessed region of each input tensor */
  Array<Region> regions;
  /*! \brief reduction axes of the call (may be empty) */
  Array<IterVar> reduce_axis;

  // reflection visitor over all fields
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("intrin", &intrin);
    v->Visit("tensors", &tensors);
    v->Visit("regions", &regions);
    v->Visit("reduce_axis", &reduce_axis);
  }

  /*!
   * \brief create a TensorIntrinCall node
   * \param intrin the intrinsic to invoke
   * \param tensors input tensors
   * \param regions accessed region of each input tensor
   * \param reduce_axis reduction axes, empty for non-reductions
   */
  static TensorIntrinCall make(TensorIntrin intrin,
                               Array<Tensor> tensors,
                               Array<Region> regions,
                               Array<IterVar> reduce_axis);

  static constexpr const char* _type_key = "TensorIntrinCall";
  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
};

// Dereference helper: view the shared node storage as the concrete
// TensorIntrinCallNode container type.
inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
  const Node* raw = node_.get();
  return static_cast<const TensorIntrinCallNode*>(raw);
}

} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
188 changes: 103 additions & 85 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,38 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
ndim = len(shape)
code = fcompute.__code__

if fcompute.__code__.co_argcount == 0:
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount

if ndim != len(arg_names):
# TODO check ndim, arg_names
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

if isinstance(body, _tensor.TensorIntrinCall):
tensor_var = []
for i, s in enumerate(shape[out_ndim:]):
name = "ax" + str(i)
tensor_var.append(_IterVar((0, s), name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
tensor_var,
body)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
# print('body: {0}'.format(body))
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
Expand Down Expand Up @@ -315,84 +332,85 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
return res[0] if len(res) == 1 else res


def _get_region(tslice):
    """Convert the indices of a tensor slice into a list of regions.

    Parameters
    ----------
    tslice : TensorSlice
        Slice of a tensor. Each index is either a python ``slice``
        (step must be None) or a scalar index — an IterVar or a plain
        expression.  (NOTE(review): assumes a TensorSlice-like object
        exposing ``.indices`` — confirm against callers.)

    Returns
    -------
    region : list of Range
        One range per index: ``[start, stop)`` for slice indices and a
        unit-extent range for scalar indices.
    """
    # The old `tensor_op` entry point this helper was nested in has been
    # superseded by TensorComputeOp; its commented-out copy follows below.
    region = []
    for idx in tslice.indices:
        if isinstance(idx, slice):
            # Only contiguous (step-less) slices are supported.
            assert idx.step is None
            region.append(Range(idx.start, idx.stop))
        else:
            # A scalar index denotes a single-element region starting
            # at that index.
            if isinstance(idx, _schedule.IterVar):
                begin = idx.var
            else:
                begin = idx
            region.append(_make.range_by_min_extent(begin, 1))
    return region


# def tensor_op(out_dims,
# in_dims, # pylint: disable=unused-argument
# finputs,
# intrin,
# raxis=None,
# name='tensor_op',
# tag=""):
# """Construct new tensors with intrinsic.
#
# Parameters
# ----------
# out_dims: tuple
# The dimensions out of the tensorized region, which can be
# scheduled through `reorder`, `split`.
#
# in_dims: tuple
# The dimensions inside of the tensorized region, which cannot
# be manipulated.
#
# finputs: lambda function of out_dims -> list of TensorSlice
# Specifies involved regions of input tensors.
#
# tensor_intrin : TensorIntrin
# The tensor intrinsic used for computation.
#
# raxis : IterVar
# An iteration variable representing the value.
#
# name: str, optional
# The name hint of the tensor
#
# tag: str, optional
# Additional tag information about the compute.
# """
# if _tag.TagScope.current is not None:
# if tag != "":
# raise ValueError("nested tag is not allowed for now")
# tag = _tag.TagScope.current.tag
#
# code = finputs.__code__
# if finputs.__code__.co_argcount == 0:
# arg_names = ["i%d" % i for i in range(ndim)]
# else:
# arg_names = code.co_varnames[:code.co_argcount]
#
# if len(out_dims) != len(arg_names):
# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims)
#
# out_var = [_IterVar((0, extent), arg_name, 0) for arg_name, extent in zip(arg_names, out_dims)]
# if isinstance(raxis, _schedule.IterVar):
# raxis = [raxis]
# if raxis is None:
# raxis = []
# tensor_regions = finputs(*[v.var for v in out_var])
#
# op = _api_internal._TensorOp(name,
# tag,
# out_var,
# raxis,
# [x.tensor for x in tensor_regions],
# [_get_region(x) for x in tensor_regions],
# intrin)
# # only support single output
# return op.output(0)


def extern(shape,
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def dtype(self):
"""Data content of the tensor."""
return self.tensor.dtype

@register_node
class TensorIntrinCall(NodeBase):
    """Intermediate node produced by applying a tensor intrinsic to
    tensor slices; consumed when building a tensor compute operation.
    """


itervar_cls = None

Expand Down Expand Up @@ -167,6 +172,6 @@ class ExternOp(Operation):


@register_node
class TensorComputeOp(Operation):
    """Operation whose body is a tensor-intrinsic call."""
7 changes: 5 additions & 2 deletions python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ class TensorIntrin(NodeBase):
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
pass

def __call__(self, *args):
    """Apply this intrinsic to the given tensor slices.

    Parameters
    ----------
    args : sequence of TensorSlice
        Regions of the input tensors the intrinsic reads; each must
        expose ``.tensor`` and be accepted by ``_api._get_region``.

    Returns
    -------
    call : TensorIntrinCall
        Intermediate node recording the intrinsic, its input tensors
        and the accessed regions.
    """
    tensors = [x.tensor for x in args]
    regions = [_api._get_region(x) for x in args]
    # TODO: derive reduce_axis from the call site; currently always
    # passed as an empty list.
    return _api_internal._TensorIntrinCall(self, tensors, regions, [])

def decl_tensor_intrin(op,
fcompute,
Expand Down
22 changes: 14 additions & 8 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
args[6]);
});

// FFI hook: forwards (intrin, tensors, regions, reduce_axis) to
// TensorIntrinCallNode::make.
TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorIntrinCallNode::make(args[0], args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
Expand Down Expand Up @@ -278,15 +286,13 @@ TVM_REGISTER_API("_ScanOp")
args[7]);
});

TVM_REGISTER_API("_TensorOp")
// FFI hook: forwards (name, tag, axis, tensor_axis, intrin_call) to
// TensorComputeOpNode::make.
TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorComputeOpNode::make(args[0], args[1], args[2],
                                     args[3], args[4]);
  });

TVM_REGISTER_API("_ExternOp")
Expand Down
Loading

0 comments on commit 1dfa7b4

Please sign in to comment.