Skip to content

Commit

Permalink
[TensorOp] Add testcase for scheduling tensor_compute_op.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Sep 24, 2018
1 parent 753cb03 commit cf0134d
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 316 deletions.
7 changes: 5 additions & 2 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class TensorComputeOpNode : public OperationNode {
public:
Array<IterVar> axis;

Array<IterVar> out_axis;

Array<IterVar> tensor_axis;

Array<IterVar> reduce_axis;
Expand Down Expand Up @@ -229,20 +231,21 @@ class TensorComputeOpNode : public OperationNode {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("out_axis", &out_axis);
v->Visit("tensor_axis", &tensor_axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("inputs", &inputs);
}

static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> out_axis,
Array<IterVar> tensor_axis,
TensorIntrinCall intrin_call);

static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> out_axis,
Array<IterVar> tensor_axis,
Array<IterVar> reduce_axis,
Array<Tensor> tensors,
Expand Down
8 changes: 0 additions & 8 deletions include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@ class TensorIntrin : public NodeRef {
*/
inline const TensorIntrinNode* operator->() const;

// template<typename... Args>
// inline Stmt operator()(Args&& ...args) const {
// Array<Expr> inputs{std::forward<Args>(args)...};
// return operator()(inputs);
// }

// TVM_DLL TensorIntrinCall operator()(Array<Expr> inputs) const;

/*! \brief specify container node */
using ContainerType = TensorIntrinNode;
};
Expand Down
91 changes: 4 additions & 87 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import absolute_import as _abs

from numbers import Integral as _Integral
from collections import namedtuple

from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
Expand Down Expand Up @@ -244,6 +243,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__

Expand All @@ -254,7 +255,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount

# TODO check ndim, arg_names
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

Expand All @@ -264,8 +264,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
if isinstance(body, _tensor.TensorIntrinCall):
tensor_var = []
for i, s in enumerate(shape[out_ndim:]):
name = "ax" + str(i)
tensor_var.append(_IterVar((0, s), name, 4))
var_name = "ax" + str(i)
tensor_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
Expand All @@ -275,7 +275,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
# print('body: {0}'.format(body))
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

Expand Down Expand Up @@ -353,88 +352,6 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
return res[0] if len(res) == 1 else res


def _get_region(tslice):
    """Convert the indices of a TensorSlice into a list of Range regions.

    A ``slice`` index becomes an explicit [start, stop) range; a scalar
    (or IterVar) index becomes a unit-extent range anchored at that point.
    """
    region = []
    for index in tslice.indices:
        if isinstance(index, slice):
            # Strided slices are not supported for tensor regions.
            assert index.step is None
            region.append(Range(index.start, index.stop))
            continue
        # Point index: use the IterVar's underlying var when given one.
        begin = index.var if isinstance(index, _schedule.IterVar) else index
        region.append(_make.range_by_min_extent(begin, 1))
    return region


# def tensor_op(out_dims,
# in_dims, # pylint: disable=unused-argument
# finputs,
# intrin,
# raxis=None,
# name='tensor_op',
# tag=""):
# """Construct new tensors with intrinsic.
#
# Parameters
# ----------
# out_dims: tuple
# The dimensions out of the tensorized region, which can be
# scheduled through `reorder`, `split`.
#
# in_dims: tuple
# The dimensions inside of the tensorized region, which cannot
# be manipulated.
#
# finputs: lambda function of out_dims -> list of TensorSlice
# Specifies involved regions of input tensors.
#
# tensor_intrin : TensorIntrin
# The tensor intrinsic used for computation.
#
# raxis : IterVar
# An iteration variable representing the value.
#
# name: str, optional
# The name hint of the tensor
#
# tag: str, optional
# Additonal tag information about the compute.
# """
# if _tag.TagScope.current is not None:
# if tag != "":
# raise ValueError("nested tag is not allowed for now")
# tag = _tag.TagScope.current.tag
#
# code = finputs.__code__
# if finputs.__code__.co_argcount == 0:
# arg_names = ["i%d" % i for i in range(ndim)]
# else:
# arg_names = code.co_varnames[:code.co_argcount]
#
# if len(out_dims) != len(arg_names):
# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims)
#
# out_var = [_IterVar((0, extent), arg_name, 0)
# for arg_name, extent in zip(arg_names, out_dims)]
# if isinstance(raxis, _schedule.IterVar):
# raxis = [raxis]
# if raxis is None:
# raxis = []
# tensor_regions = finputs(*[v.var for v in out_var])
#
# op = _api_internal._TensorOp(name,
# tag,
# out_var,
# raxis,
# [x.tensor for x in tensor_regions],
# [_get_region(x) for x in tensor_regions],
# intrin)
# # only support single output
# return op.output(0)


def extern(shape,
inputs,
fcompute,
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def reduce_axis(self):
return self.__getattr__("reduce_axis")


@register_node
class TensorComputeOp(Operation):
    """Operation node for a tensorized compute.

    Produced when ``compute`` is given a body that is a
    ``TensorIntrinCall`` (a tensor-intrinsic invocation) rather than a
    scalar expression; mirrors the C++ ``TensorComputeOpNode``.
    """
    pass


@register_node
class ScanOp(Operation):
"""Scan operation."""
Expand All @@ -174,9 +180,3 @@ def scan_axis(self):
class ExternOp(Operation):
"""Extern operation."""
pass


@register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass
18 changes: 17 additions & 1 deletion python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,25 @@
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node


def _get_region(tslice):
    """Build the list of Ranges covering the region described by ``tslice``.

    Slice indices map to explicit [start, stop) ranges; scalar or
    IterVar indices map to ranges of extent 1 at that coordinate.
    """
    ranges = []
    for idx in tslice.indices:
        if isinstance(idx, slice):
            # Only contiguous (step-less) slices are accepted.
            assert idx.step is None
            ranges.append(_api.Range(idx.start, idx.stop))
        else:
            if isinstance(idx, _schedule.IterVar):
                start = idx.var
            else:
                start = idx
            ranges.append(_make.range_by_min_extent(start, 1))
    return ranges

@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
Expand All @@ -19,7 +35,7 @@ class TensorIntrin(NodeBase):
"""
def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_api._get_region(x) for x in args]
regions = [_get_region(x) for x in args]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
Expand Down
6 changes: 0 additions & 6 deletions src/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ Tensor Operation::output(size_t i) const {
return Tensor(node);
}

// TensorIntrinCall TensorIntrin::operator()(Array<Expr> inputs) const {
// using HalideIR::Internal::Call;
// LOG(FATAL) << "CallTensorIntrin";
// CHECK_EQ(tensors.size(), regions.size());
// }

Tensor TensorNode::make(Array<Expr> shape,
Type dtype,
Operation op,
Expand Down
Loading

0 comments on commit cf0134d

Please sign in to comment.