From 92d7eafee612c46005690db91614d56a139a65e2 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Tue, 24 Mar 2020 23:59:55 +0000 Subject: [PATCH] Improve test --- python/tvm/relay/prelude.py | 3 ++- tests/python/relay/test_adt.py | 17 +---------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index a9e5a684d6248..42ede20d01682 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -39,7 +39,8 @@ def __init__(self, prelude, dtype, shape): def get_name(self, canonical): """Get name corresponding to the canonical name""" - shape_str = str(self.shape).replace('[', '').replace(']', '').replace(', ', '_') + shape_str = str(self.shape).replace('[', '').replace(']', '')\ + .replace('(', '').replace(')', '').replace(', ', '_') if len(shape_str) == 0: shape_str = "scalar" if canonical == 'tensor_t': diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 2a114d82d69b5..2dc25ecdffaac 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.prelude import Prelude +from tvm.relay.prelude import Prelude, StaticTensorArrayOps from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np @@ -984,7 +984,6 @@ def test_static_tensor_take(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1006,7 +1005,6 @@ def test_static_tensor_concatenate(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1028,7 +1026,6 @@ def run(dtype, shape): x = relay.var('x') mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1046,7 +1043,6 @@ def run(dtype, shape): x = relay.var('x') mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape) @@ -1057,7 +1053,6 @@ def test_static_tensor_array_read(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1096,7 +1091,6 @@ def test_static_tensor_array_write(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1122,7 +1116,6 @@ def test_static_tensor_array_unstack(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1139,7 +1132,6 @@ def test_static_tensor_array_scatter(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1191,7 +1183,6 @@ def test_static_tensor_array_split(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1245,7 +1236,6 @@ def test_static_tensor_array_concat(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1271,7 +1261,6 @@ def test_static_tensor_array_gather(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1298,7 +1287,6 @@ def test_static_tensor_array_stack(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1323,7 +1311,6 @@ def test_static_tensor_get_data_same_shape(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1363,7 +1350,6 @@ def test_static_tensor_get_data_expand_shape(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() @@ -1391,7 +1377,6 @@ def test_static_tensor_get_data_replace_shape(): def run(dtype, shape): mod = tvm.IRModule() p = Prelude(mod) - from tvm.relay.prelude import StaticTensorArrayOps static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register()