-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Numpy compatible dtype inference for tvm.convert
and tvm.const
#3861
Changes from 3 commits
46ce36c
7d0b131
decacff
152930d
ed033ab
1e2042d
37b8687
2986a35
b0eb311
43c0d1d
ed332fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -25,11 +25,29 @@ | |||||
# Node base class | ||||||
_CLASS_NODE_BASE = None | ||||||
|
||||||
|
||||||
def _set_class_node_base(cls): | ||||||
global _CLASS_NODE_BASE | ||||||
_CLASS_NODE_BASE = cls | ||||||
|
||||||
|
||||||
def _scalar_type_inference(value): | ||||||
if hasattr(value, 'dtype'): | ||||||
dtype = str(value.dtype) | ||||||
elif isinstance(value, bool): | ||||||
dtype = 'bool' | ||||||
elif isinstance(value, float): | ||||||
# We intentionally convert the float to float32 since it's more common in DL. | ||||||
dtype = 'float32' | ||||||
elif isinstance(value, int): | ||||||
# We intentionally convert the python int to int32 since it's more common in DL. | ||||||
dtype = 'int32' | ||||||
else: | ||||||
raise NotImplementedError('Cannot automatically inference the type.' | ||||||
' value={}'.format(value)) | ||||||
return dtype | ||||||
|
||||||
|
||||||
class NodeGeneric(object): | ||||||
"""Base class for all classes that can be converted to node.""" | ||||||
def asnode(self): | ||||||
|
@@ -86,7 +104,7 @@ def const(value, dtype=None): | |||||
value : int or float | ||||||
The input value | ||||||
|
||||||
dtype : str | ||||||
dtype : str or None, optional | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure which one would be better
Suggested change
|
||||||
The data type. | ||||||
|
||||||
Returns | ||||||
|
@@ -95,8 +113,5 @@ def const(value, dtype=None): | |||||
Constant expression corresponds to the value. | ||||||
""" | ||||||
if dtype is None: | ||||||
if isinstance(value, Integral): | ||||||
dtype = 'int32' | ||||||
else: | ||||||
dtype = 'float32' | ||||||
dtype = _scalar_type_inference(value) | ||||||
return _api_internal._const(value, dtype) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
from ._ffi.base import string_types | ||
from ._ffi.node import register_node, NodeBase | ||
from ._ffi.node import convert_to_node as _convert_to_node | ||
from ._ffi.node_generic import _scalar_type_inference | ||
from ._ffi.function import Function | ||
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs | ||
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func | ||
|
@@ -73,22 +74,24 @@ def max_value(dtype): | |
return _api_internal._max_value(dtype) | ||
|
||
|
||
def const(value, dtype): | ||
def const(value, dtype=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey, why make it optional? As a compiler, we should be more serious about which type we would feed into the IR framework. |
||
"""construct a constant | ||
|
||
Parameters | ||
---------- | ||
value : number | ||
The content of the constant number. | ||
|
||
dtype : str | ||
dtype : str or None, optional | ||
The data type. | ||
|
||
Returns | ||
------- | ||
const_val: tvm.Expr | ||
The result expression. | ||
""" | ||
if dtype is None: | ||
dtype = _scalar_type_inference(value) | ||
return _api_internal._const(value, dtype) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you like to check overflow?