diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 34e3fdd4e760..453058556880 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -833,7 +833,7 @@ def verify(data_shape, out_shape, begin, end): def test_forward_convolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): - weight_shape=(num_filter,1,) + kernel_size + weight_shape=(num_filter, data_shape[1],) + kernel_size x = np.random.uniform(size=data_shape).astype("float32") weight = np.random.uniform(size=weight_shape).astype("float32") bias = np.random.uniform(size=num_filter).astype("float32") @@ -852,11 +852,17 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) def test_forward_deconvolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): - weight_shape=(1, num_filter) + kernel_size + weight_shape=(data_shape[1], num_filter) + kernel_size x = np.random.uniform(size=data_shape).astype("float32") weight = np.random.uniform(size=weight_shape).astype("float32") bias = np.random.uniform(size=num_filter).astype("float32") @@ -875,7 +881,13 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) if __name__ == '__main__':