Skip to content

Commit

Permalink
Update precision in the ONNX strided_slice, update precision of ToSca…
Browse files Browse the repository at this point in the history
…lar (apache#6272)

* Update precision in the ONNX strided_slice, update precision of ToScalar

* fix tests
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Aug 26, 2020
1 parent 74f764d commit 4a03e1d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,8 +1045,8 @@ def _impl_v1(cls, inputs, attr, params):
end = list(attr['ends'])

return _op.strided_slice(inputs[0],
begin=_expr.const(begin, dtype="int32"),
end=_expr.const(end, dtype="int32"))
begin=_expr.const(begin, dtype="int64"),
end=_expr.const(end, dtype="int64"))

@classmethod
def _impl_v10(cls, inputs, attr, params):
Expand All @@ -1063,8 +1063,8 @@ def _impl_v10(cls, inputs, attr, params):
starts = new_starts
ends = new_ends
return _op.strided_slice(inputs[0],
begin=_expr.const(starts, dtype="int32"),
end=_expr.const(ends, dtype="int32"))
begin=_expr.const(starts, dtype="int64"),
end=_expr.const(ends, dtype="int64"))


class Gather(OnnxOpConverter):
Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \param i element index
* \return Converted scalar value.
*/
static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<int8_t*>(array->data)[i];
Expand Down Expand Up @@ -423,8 +423,8 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
size_t len = array.Shape().front();
Array<Integer> out;
for (size_t i = 0; i < len; ++i) {
double elem_val = ToScalar(array, i);
out.push_back(Integer(static_cast<int>(elem_val)));
long double elem_val = ToScalar(array, i);
out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
}
return out;
}
Expand Down
11 changes: 6 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,15 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
inputs = [
helper.make_tensor_value_info("data", TensorProto.FLOAT,
list(indata.shape)),
helper.make_tensor_value_info("starts", TensorProto.INT32,
helper.make_tensor_value_info("starts", TensorProto.INT64,
list(starts.shape)),
helper.make_tensor_value_info("ends", TensorProto.INT32,
helper.make_tensor_value_info("ends", TensorProto.INT64,
list(ends.shape))
]
initializer = [
helper.make_tensor("starts", TensorProto.INT32, list(starts.shape),
helper.make_tensor("starts", TensorProto.INT64, list(starts.shape),
starts),
helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends)
helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends)
]

if axes:
Expand Down Expand Up @@ -534,7 +534,8 @@ def test_slice():
_test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1))
x = np.random.randn(1, 1, 1, 128).astype(np.float32)
_test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3))


def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
Expand Down

0 comments on commit 4a03e1d

Please sign in to comment.