Skip to content

Commit

Permalink
[Relay/TOPI][OP] Add arange op in Relay and TOPI (apache#2621)
Browse files Browse the repository at this point in the history
* Add arange op

* Update docs

* Fix bug

* add sanity check in relay and mxnet frontend mapping

* lint

* nits

* pylint

* don't allow empty output from arange

* Remove empty test for arange

* Fix bug and update doc
  • Loading branch information
icemelon authored and wweic committed Mar 9, 2019
1 parent 8a68378 commit 20ef1e4
Show file tree
Hide file tree
Showing 13 changed files with 309 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.arange
topi.image.resize


Expand Down Expand Up @@ -123,6 +124,7 @@ topi
.. autofunction:: topi.power
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.arange

topi.nn
~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ This level enables additional math and transform operators.
tvm.relay.full_like
tvm.relay.cast
tvm.relay.split
tvm.relay.arange


**Level 4: Broadcast and Reductions**
Expand Down Expand Up @@ -216,6 +217,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange


Level 4 Definitions
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,25 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
}
}; // struct InitOpAttrs

/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start;
tvm::Expr stop;
tvm::Expr step;
DataType dtype;

TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
.describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Target data type.");
}
}; // struct ArangeAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,18 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_arange(inputs, attrs):
assert len(inputs) == 0
if attrs.get_int("repeat", 1) != 1:
raise RuntimeError("arange doesn't support repeat")
new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0)
new_attrs["stop"] = attrs.get_float("stop")
new_attrs["step"] = attrs.get_float("step", 1)
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)


def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
Expand Down Expand Up @@ -362,6 +374,7 @@ def _mx_roi_align(inputs, attrs):
"Concat" : _mx_concat,
"concat" : _mx_concat,
"LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
# vision
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
Expand Down
52 changes: 49 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
Sizes for both array should be compatible.
Parameters
----------
Expand Down Expand Up @@ -249,10 +250,57 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value)


def arange(start, stop=None, step=1, dtype="float32"):
"""Return evenly spaced values within a given interval.
.. note::
Similar to ``numpy.arange``, when only one argument is given, it is used
as `stop` instead of `start` while `start` takes default value 0.
Warning: Undefined behavior when dtype is incompatible with start/stop/step.
It could lead to different results compared to numpy, MXNet, pytorch, etc.
Parameters
----------
start : tvm.Expr, optional
Start of interval. The interval includes this value. The default start
value is 0.
stop : tvm.Expr
Stop of interval. The interval does not include this value.
step : tvm.Expr, optional
Spacing between values. The default step size is 1.
dtype : str, optional
The target data type.
Returns
-------
result : relay.Expr
The resulting tensor.
Examples
--------
.. code-block:: python
relay.arange(5) = [0, 1, 2, 3, 4]
relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4]
"""
if stop is None:
stop = start
start = 0
return _make.arange(start, stop, step, dtype)


def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
.. note::
The shape of condition, x, and y needs to be the same.
Parameters
----------
condition : relay.Expr
Expand Down Expand Up @@ -282,8 +330,6 @@ def where(condition, x, y):
condition = [1, 0]
relay.where(conditon, x, y) = [[1, 2], [7, 8]]
Note that the shape of condition, x, and y needs to be the same.
"""
return _make.where(condition, x, y)

Expand Down
57 changes: 57 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,63 @@ and type as the input array.
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

bool ArangeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 1);
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
CHECK_GT(val->value, 0)
<< "Invalid arange attributes (start, stop, step): " << param->start
<< ", " << param->stop << ", " << param->step;
}
reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
return true;
}

Array<Tensor> ArangeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
return { topi::arange(param->start, param->stop, param->step, param->dtype) };
}

Expr MakeArange(tvm::Expr start,
tvm::Expr stop,
tvm::Expr step,
DataType dtype) {
auto attrs = make_node<ArangeAttrs>();
attrs->start = std::move(start);
attrs->stop = std::move(stop);
attrs->step = std::move(step);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("arange");
return CallNode::make(op, {}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
});

RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types,
int num_inputs,
Expand Down
60 changes: 41 additions & 19 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,30 +203,51 @@ def test_forward_where():
mx_cond = mx.nd.array(np_cond)
mx_x = mx.nd.array(np_x)
mx_y = mx.nd.array(np_y)
shapes = {'cond': dshape, 'x': dshape, 'y': dshape}
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
mod.bind(data_shapes=shapes.items(), for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
out_shape = dshape
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
new_sym, params = relay.frontend.from_mxnet(mx_sym,
shape_dict,
arg_params=args,
aux_params=auxs)

new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs)
for target, ctx in ctx_list():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(new_sym, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("cond", tvm.nd.array(np_cond))
m.set_input("x", tvm.nd.array(np_x))
m.set_input("y", tvm.nd.array(np_y))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(np_cond, np_x, np_y)
tvm.testing.assert_allclose(op_res.asnumpy(), mx_out)


def test_forward_arange():
def _mx_symbol(F, start, stop, step):
if start is None and step is None:
sym = F.arange(stop)
elif start is None:
sym = F.arange(stop, step=step)
elif step is None:
sym = F.arange(start, stop)
else:
sym = F.arange(start, stop, step)
return sym

def verify(start, stop, step):
ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy()
mx_sym = _mx_symbol(mx.sym, start, stop, step)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
verify(0, 20, None)
verify(0, 20, 2)
verify(1, 20, None)
verify(1, 20, 2)
verify(1, 20, 1.5)
verify(1, 20.5, None)
verify(1, 20, 3)
verify(20, 1, -1)
verify(20, 1, -1.5)


if __name__ == '__main__':
Expand All @@ -251,3 +272,4 @@ def test_forward_where():
test_forward_argmax()
test_forward_argmin()
test_forward_where()
test_forward_arange()
35 changes: 35 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,40 @@ def test_infer_type_prelu():
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))


def test_arange():
def verify_arange(start, stop, step):
dtype = "float32"
if start is None and step is None:
x = relay.arange(stop)
ref_res = np.arange(stop)
elif start is None:
x = relay.arange(stop, step=step)
ref_res = np.arange(stop, step=step)
elif step is None:
x = relay.arange(start, stop)
ref_res = np.arange(start, stop)
else:
x = relay.arange(start, stop, step)
ref_res = np.arange(start, stop, step)

func = relay.Function([], x)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_arange(None, 20, None)
verify_arange(None, 20, 2)
verify_arange(1, 20, None)
verify_arange(1, 20, 2)
verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None)
verify_arange(1, 20, 3)
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)


if __name__ == "__main__":
test_cast()
test_zeros_ones()
Expand All @@ -480,3 +514,4 @@ def test_infer_type_prelu():
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
test_arange()
13 changes: 13 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,19 @@ inline Tensor tensordot(const Tensor& A,
return compute(output_shape, func, name, tag);
}

inline Tensor arange(const Expr start,
const Expr stop,
const Expr step,
Type dtype,
std::string name = "tensor",
std::string tag = kInjective) {
Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), stop - start) / step));
Array<Expr> shape;
return compute({num_elem}, [&](const Array<Var>& indices) {
return tvm::cast(dtype, start + step * indices[0]);
}, name, tag);
}

} // namespace topi
#endif // TOPI_TRANSFORM_H_
Loading

0 comments on commit 20ef1e4

Please sign in to comment.