From ae0a0627b644fc07247a82f432e939810c63b382 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Sat, 8 Aug 2020 18:26:36 -0700 Subject: [PATCH] [FIX] Verify that tensor reshape is valid. (#6215) --- src/relay/op/tensor/transform.cc | 31 +++++++++++++++++++ tests/python/relay/test_op_level3.py | 9 ++++++ .../relay/test_pass_combine_parallel_dense.py | 4 +-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9427dedfe3fa..79a8da4fea75 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -581,6 +582,36 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.Set(infer_idx, infer_dim); } + // Verify that the sum of dimensions in the output shape is the sum of + // dimensions in the input shape + bool found_dynamic = false; + int64_t oshape_sum = 1; + for (auto& x : oshape) { + // Check if we have a dynamic shape. If we do, we can't verify if the + // reshape is valid. Dynamic shapes are marker by using Any, but can also + // occur from SizeVar's. In the case of SizeVar, the shape expression can + // be an AST. We can't easily check if we have an AST because of a ShapeVar + // or some other reason, so our check for dynamic shape is just if we can + // convert the shape to in integer or not. + if (!x->IsInstance()) { + found_dynamic = true; + break; + } + oshape_sum *= Downcast(x)->value; + } + int64_t data_shape_sum = 1; + for (auto& x : data_shape) { + if (!x->IsInstance()) { + found_dynamic = true; + break; + } + data_shape_sum *= Downcast(x)->value; + } + if (!found_dynamic) { + CHECK_EQ(oshape_sum, data_shape_sum) + << "Input tensor shape and reshaped shape are not compatible"; + } + if (param->reverse) { reporter->Assign(types[1], TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 76f10d6c1a18..db45fcbef6cb 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,6 +21,7 @@ import tvm from tvm import te from tvm import relay +from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import ctx_list, check_grad, run_infer_type @@ -282,6 +283,13 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) +def test_reshape_fail(): + with pytest.raises(TVMError) as reshape_err: + x = relay.var("x", relay.TensorType([2,3], "float32")) + z = relay.reshape(x, [7]) + zz = run_infer_type(z) + + def test_reshape_like_infer_type(): # concrete shape x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) @@ -1070,6 +1078,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ test_transpose() test_reshape_infer_type() test_reshape() + test_reshape_fail() test_reshape_like_infer_type() test_reshape_like() test_take_infer_type() diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index a8d7f11c471a..535f97a39d08 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -187,7 +187,7 @@ def check(i, j, k, scale1, scale2, newshape): tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check(3, 5, 4, 0.5, 0.25, (1, 1, 15)) - check(100, 200, 300, 0.5, 0.25, (1, 1, 200)) + check(100, 200, 300, 0.5, 0.25, (1, 1, 20000)) def test_combine_parallel_dense_flat(): @@ -369,7 +369,7 @@ def check(i, j, k, scale1, scale2, newshape1, newshape2): tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check(3, 5, 4, 0.5, 0.25, (1, 1, 15), (1, 1, 30)) - check(100, 200, 300, 0.5, 0.25, (1, 1, 200), (1, 1, 400)) + check(100, 200, 300, 0.5, 0.25, (1, 1, 20000), (1, 1, 40000)) if __name__ == "__main__":