Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpy compatible dtype inference for tvm.convert and tvm.const #3861

Merged
merged 11 commits into from
Sep 9, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions python/tvm/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,29 @@
# Node base class
_CLASS_NODE_BASE = None


def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls


def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you like to check overflow?

# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype


class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
Expand Down Expand Up @@ -86,7 +104,7 @@ def const(value, dtype=None):
value : int or float
The input value

dtype : str
dtype : str or None, optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure which one would be better

Suggested change
dtype : str or None, optional
dtype : Optional[str]

The data type.

Returns
Expand All @@ -95,8 +113,5 @@ def const(value, dtype=None):
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)
7 changes: 5 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
Expand Down Expand Up @@ -73,22 +74,24 @@ def max_value(dtype):
return _api_internal._max_value(dtype)


def const(value, dtype):
def const(value, dtype=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, why make it optional? As a compiler, we should be more serious about which type we would feed into the IR framework.

"""construct a constant

Parameters
----------
value : number
The content of the constant number.

dtype : str
dtype : str or None, optional
The data type.

Returns
-------
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)


Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np

def test_const():
x = tvm.const(1, "int32")
print(x.dtype)
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)


def test_scalar_dtype_inference():
for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.const(data).dtype == str(np.array(data).dtype)
assert tvm.const(1).dtype == 'int32'
assert tvm.const(1.0).dtype == 'float32'

for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.convert(data).dtype == str(np.array(data).dtype)
assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32'

def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
Expand Down Expand Up @@ -175,6 +192,7 @@ def test_equality_string_imm():
test_cast()
test_attr()
test_const()
test_scalar_dtype_inference()
test_make()
test_ir()
test_basic()
Expand Down