Numpy compatible dtype inference for tvm.convert and tvm.const (#3861)
Changes from 10 commits: 46ce36c, 7d0b131, decacff, 152930d, ed033ab, 1e2042d, 37b8687, 2986a35, b0eb311, 43c0d1d, ed332fa
@@ -30,6 +30,23 @@ def _set_class_node_base(cls):
     _CLASS_NODE_BASE = cls


+def _scalar_type_inference(value):
+    if hasattr(value, 'dtype'):
+        dtype = str(value.dtype)
+    elif isinstance(value, bool):
+        dtype = 'bool'
+    elif isinstance(value, float):
+        # We intentionally convert the float to float32 since it's more common in DL.
+        dtype = 'float32'
+    elif isinstance(value, int):
+        # We intentionally convert the python int to int32 since it's more common in DL.
+        dtype = 'int32'
+    else:
+        raise NotImplementedError('Cannot automatically inference the type.'
+                                  ' value={}'.format(value))
+    return dtype
+
+
 class NodeGeneric(object):
     """Base class for all classes that can be converted to node."""
     def asnode(self):

@@ -86,7 +103,7 @@ def const(value, dtype=None):
     value : int or float
         The input value

-    dtype : str
+    dtype : str or None, optional
         The data type.

Review comment: Not sure which one would be better.

     Returns

@@ -95,8 +112,5 @@ def const(value, dtype=None):
         Constant expression corresponds to the value.
     """
     if dtype is None:
-        if isinstance(value, Integral):
-            dtype = 'int32'
-        else:
-            dtype = 'float32'
+        dtype = _scalar_type_inference(value)
     return _api_internal._const(value, dtype)
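
For context, a minimal sketch (not part of the diff) of how the new inference is expected to behave once this lands; the asserted dtypes are expectations under this branch, not captured output:

```python
import numpy as np
import tvm

# Values carrying a NumPy dtype keep it; bare Python scalars fall back to the
# DL-friendly defaults chosen above (int32 / float32).
assert tvm.const(1).dtype == 'int32'                  # Python int   -> int32
assert tvm.const(1.5).dtype == 'float32'              # Python float -> float32
assert tvm.const(np.int64(3)).dtype == 'int64'        # NumPy scalar keeps its dtype
assert tvm.const(np.float64(1.5)).dtype == 'float64'
```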
@@ -23,6 +23,7 @@
 from ._ffi.base import string_types
 from ._ffi.node import register_node, NodeBase
 from ._ffi.node import convert_to_node as _convert_to_node
+from ._ffi.node_generic import _scalar_type_inference
 from ._ffi.function import Function
 from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
 from ._ffi.function import convert_to_tvm_func as _convert_tvm_func

@@ -73,22 +74,24 @@ def max_value(dtype):
     return _api_internal._max_value(dtype)


-def const(value, dtype):
+def const(value, dtype=None):

Review comment: Hey, why make it optional? As a compiler, we should be more serious about which type we would feed into the IR framework.

     """construct a constant

     Parameters
     ----------
     value : number
         The content of the constant number.

-    dtype : str
+    dtype : str or None, optional
         The data type.

     Returns
     -------
     const_val: tvm.Expr
         The result expression.
     """
+    if dtype is None:
+        dtype = _scalar_type_inference(value)
     return _api_internal._const(value, dtype)
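
A small usage sketch (illustrative only) of the updated public `tvm.const` signature: an explicit dtype still takes precedence, and values the helper cannot classify are rejected rather than silently guessed.

```python
import tvm

a = tvm.const(10)             # dtype omitted: inferred as 'int32' on this branch
b = tvm.const(10, 'int64')    # explicit dtype still wins
c = tvm.const(0.5)            # inferred as 'float32'
assert (a.dtype, b.dtype, c.dtype) == ('int32', 'int64', 'float32')

# A value with no dtype attribute that is not a Python bool/int/float raises.
try:
    tvm.const('3')
except NotImplementedError:
    pass
```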
@@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
                                            std::string tag = kInjective) {
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));

Review comment: so these shape could be int64 when passed in?
Reply: Yes, they can be int64 because we only constrain them to be Expr.

   out_shape.push_back(input->shape[3]);

   return compute(

@@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));

   return compute(
     out_shape, [&](const Array<Var>& indices) {

@@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
   out_shape.push_back(input->shape[4]);

   return compute(

@@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
                                    std::string tag = kInjective) {
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
   out_shape.push_back(input->shape[3]);

   Expr cone = make_const(Int(32), 1);

@@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));

   Expr cone = make_const(Int(32), 1);
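
A Python-side sketch of the pattern these hunks apply (names and shapes are illustrative; the actual change is in the C++ resize kernels above): extents that may arrive as int64 Exprs are normalized to int32 before being mixed with the input tensor's int32 dims.

```python
import tvm

data = tvm.placeholder((1, 32, 32, 16), name='data')   # dims default to int32
new_h = tvm.const(64, 'int64')                          # extents supplied as int64 Exprs
new_w = tvm.const(64, 'int64')

# Normalize the incoming extents to int32 before building the output shape,
# mirroring the cast(Int(32), shape[i]) calls above.
out_shape = [data.shape[0], new_h.astype('int32'), new_w.astype('int32'), data.shape[3]]

# Crude 2x nearest-neighbour upsample, just to show the shape being consumed.
out = tvm.compute(out_shape, lambda n, y, x, c: data[n, y // 2, x // 2, c], name='resize2x')
```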
@@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
   CHECK_GE(pad_before.size(), 1);
   CHECK_EQ(pad_before.size(), pad_after.size());
   tvm::Array<tvm::Expr> output_shape;
+  tvm::Array<tvm::Expr> pad_before_int32;
+  tvm::Array<tvm::Expr> pad_after_int32;
+  for (const auto &ele : pad_before) {
+    pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
+  }
+  for (const auto &ele : pad_after) {
+    pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
+  }
   for (size_t i = 0; i < t->shape.size(); ++i) {
     if (i >= pad_before.size()) {
       output_shape.push_back(t->shape[i]);
     } else {
       output_shape.push_back(
-          tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));

Review comment: I find that there are some data type mismatch problems in the implementation of TOPI. For example, here,

+          tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
     }
   }

@@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
   tvm::Array<tvm::Expr> indices;
   tvm::Array<tvm::Expr> sel;
   for (size_t i = 0; i < t->shape.size(); ++i) {
-    if (i >= pad_before.size()) {
+    if (i >= pad_before_int32.size()) {
       indices.push_back(ovars[i]);
       continue;
     }
-    if (!topi::detail::EqualCheck(pad_before[i], 0)) {
-      sel.push_back(ovars[i] >= pad_before[i]);
-      indices.push_back(ovars[i] - pad_before[i]);
+    if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
+      sel.push_back(ovars[i] >= pad_before_int32[i]);
+      indices.push_back(ovars[i] - pad_before_int32[i]);
     } else {
       indices.push_back(ovars[i]);
     }
-    if (!topi::detail::EqualCheck(pad_after[i], 0)) {
-      sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
+    if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
+      sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
     }
   }
   if (sel.size() != 0) {
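
For illustration, the same normalization expressed on the Python side (hypothetical values; the real fix is the C++ pad change above): cast the pad extents to int32 before they enter shape arithmetic with the tensor's int32 dims.

```python
import tvm

data = tvm.placeholder((1, 3, 32, 32), name='data')
pad_before = [tvm.const(v, 'int64') for v in (0, 0, 1, 1)]   # int64 extents, e.g. from attrs
pad_after = [tvm.const(v, 'int64') for v in (0, 0, 1, 1)]

# Cast to int32 first, as the C++ code now does with pad_before_int32/pad_after_int32,
# then fold the output-shape arithmetic.
pb = [p.astype('int32') for p in pad_before]
pa = [p.astype('int32') for p in pad_after]
out_shape = [tvm.ir_pass.Simplify(d + b + a) for d, b, a in zip(data.shape, pb, pa)]
```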
@@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x,
   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

-  auto kernel_height = kernel_size[0];
-  auto kernel_width = kernel_size[1];
-  auto stride_height = stride_size[0];
-  auto stride_width = stride_size[1];
+  auto kernel_height = cast(Int(32), kernel_size[0]);
+  auto kernel_width = cast(Int(32), kernel_size[1]);
+  auto stride_height = cast(Int(32), stride_size[0]);
+  auto stride_width = cast(Int(32), stride_size[1]);

   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];

-  auto pad_top = padding_size[0];
-  auto pad_left = padding_size[1];
-  auto pad_bottom = padding_size[2];
-  auto pad_right = padding_size[3];
+  auto pad_top = cast(Int(32), padding_size[0]);
+  auto pad_left = cast(Int(32), padding_size[1]);
+  auto pad_bottom = cast(Int(32), padding_size[2]);
+  auto pad_right = cast(Int(32), padding_size[3]);

   if (ceil_mode) {
     // Additional padding to ensure we do ceil instead of floor when

@@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

-  auto kernel_height = kernel_size[0];
-  auto kernel_width = kernel_size[1];
-  auto stride_height = stride_size[0];
-  auto stride_width = stride_size[1];
+  auto kernel_height = cast(Int(32), kernel_size[0]);
+  auto kernel_width = cast(Int(32), kernel_size[1]);
+  auto stride_height = cast(Int(32), stride_size[0]);
+  auto stride_width = cast(Int(32), stride_size[1]);

   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];

-  auto pad_top = padding_size[0];
-  auto pad_left = padding_size[1];
-  auto pad_bottom = padding_size[2];
-  auto pad_right = padding_size[3];
+  auto pad_top = cast(Int(32), padding_size[0]);
+  auto pad_left = cast(Int(32), padding_size[1]);
+  auto pad_bottom = cast(Int(32), padding_size[2]);
+  auto pad_right = cast(Int(32), padding_size[3]);

   if (ceil_mode) {
     // Additional padding to ensure we do ceil instead of floor when

@@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];

-  auto out_height = output_size[0];
-  auto out_width = output_size[1];
+  auto out_height = cast(Int(32), output_size[0]);
+  auto out_width = cast(Int(32), output_size[1]);

Review comment: fyi @anijain2305 Is it good in quantization?
Reply: Yes, this is good. We can cast back to x->dtype if it is not FP32 just before the divide factor.

   Array<Expr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
Review comment: Would you like to check overflow?
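
On the overflow question: one possible guard (a hypothetical sketch, not part of this PR) would be to reject constant extents that do not fit in int32 before casting.

```python
import tvm

def cast_to_i32_checked(expr):
    """Illustrative helper: cast an extent to int32, refusing constant values
    that would overflow; symbolic extents are cast unchecked."""
    if isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
        if not -(2 ** 31) <= expr.value < 2 ** 31:
            raise ValueError("extent %d does not fit in int32" % expr.value)
    return expr.astype('int32')

cast_to_i32_checked(tvm.const(64, 'int64'))         # fine
# cast_to_i32_checked(tvm.const(2 ** 40, 'int64'))  # would raise ValueError
```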