Skip to content

Commit

Permalink
Numpy compatible dtype inference for tvm.convert and tvm.const (a…
Browse files Browse the repository at this point in the history
…pache#3861)

* numpy compatible type inference

* update

* try to fix

* fix

* try to fix

* fix lint

* Update nn.h

* cast to int32

* try to fix

* fix again

* retrigger ci
  • Loading branch information
sxjscience authored and wweic committed Sep 16, 2019
1 parent 839c814 commit bebc644
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 45 deletions.
24 changes: 19 additions & 5 deletions python/tvm/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ def _set_class_node_base(cls):
_CLASS_NODE_BASE = cls


def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype


class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
Expand Down Expand Up @@ -86,7 +103,7 @@ def const(value, dtype=None):
value : int or float
The input value
dtype : str
dtype : str or None, optional
The data type.
Returns
Expand All @@ -95,8 +112,5 @@ def const(value, dtype=None):
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)
7 changes: 5 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
Expand Down Expand Up @@ -73,22 +74,24 @@ def max_value(dtype):
return _api_internal._max_value(dtype)


def const(value, dtype):
def const(value, dtype=None):
"""construct a constant
Parameters
----------
value : number
The content of the constant number.
dtype : str
dtype : str or None, optional
The data type.
Returns
-------
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)


Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np

def test_const():
x = tvm.const(1, "int32")
print(x.dtype)
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)


def test_scalar_dtype_inference():
for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.const(data).dtype == str(np.array(data).dtype)
assert tvm.const(1).dtype == 'int32'
assert tvm.const(1.0).dtype == 'float32'

for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.convert(data).dtype == str(np.array(data).dtype)
assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32'

def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
Expand Down Expand Up @@ -175,6 +192,7 @@ def test_equality_string_imm():
test_cast()
test_attr()
test_const()
test_scalar_dtype_inference()
test_make()
test_ir()
test_basic()
Expand Down
20 changes: 10 additions & 10 deletions topi/include/topi/image/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);

return compute(
Expand Down Expand Up @@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));

return compute(
out_shape, [&](const Array<Var>& indices) {
Expand Down Expand Up @@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[4]);

return compute(
Expand Down Expand Up @@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);

Expr cone = make_const(Int(32), 1);
Expand Down Expand Up @@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));

Expr cone = make_const(Int(32), 1);

Expand Down
22 changes: 15 additions & 7 deletions topi/include/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::Expr> output_shape;
tvm::Array<tvm::Expr> pad_before_int32;
tvm::Array<tvm::Expr> pad_after_int32;
for (const auto &ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (const auto &ele : pad_after) {
pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
output_shape.push_back(t->shape[i]);
} else {
output_shape.push_back(
tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));
tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
}
}

Expand All @@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel;
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]);
continue;
}
if (!topi::detail::EqualCheck(pad_before[i], 0)) {
sel.push_back(ovars[i] >= pad_before[i]);
indices.push_back(ovars[i] - pad_before[i]);
if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
sel.push_back(ovars[i] >= pad_before_int32[i]);
indices.push_back(ovars[i] - pad_before_int32[i]);
} else {
indices.push_back(ovars[i]);
}
if (!topi::detail::EqualCheck(pad_after[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
}
if (sel.size() != 0) {
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/nn/dilate.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline Tensor dilate(const Tensor& x,
Array<Expr> out_shape;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify(
(x->shape[i] - 1) * strides[i] + 1));
(x->shape[i] - 1) * cast(Int(32), strides[i] + 1)));
}

return tvm::compute(
Expand Down
36 changes: 18 additions & 18 deletions topi/include/topi/nn/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = cast(Int(32), stride_size[1]);

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];
auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = cast(Int(32), padding_size[3]);

if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
Expand Down Expand Up @@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = cast(Int(32), stride_size[1]);

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];
auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = cast(Int(32), padding_size[3]);

if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
Expand Down Expand Up @@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto out_height = output_size[0];
auto out_width = output_size[1];
auto out_height = cast(Int(32), output_size[0]);
auto out_width = cast(Int(32), output_size[1]);
Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
Expand Down
9 changes: 7 additions & 2 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,14 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<Expr> newshape_int32;

for (const auto &ele : newshape) {
newshape_int32.push_back(cast(Int(32), ele));
}
return compute(
newshape, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape),
newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape));
}, name, tag);
}
Expand Down

0 comments on commit bebc644

Please sign in to comment.