Skip to content

Commit

Permalink
Improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Mar 24, 2020
1 parent ee6c7e7 commit 92d7eaf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 17 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self, prelude, dtype, shape):

def get_name(self, canonical):
"""Get name corresponding to the canonical name"""
shape_str = str(self.shape).replace('[', '').replace(']', '').replace(', ', '_')
shape_str = str(self.shape).replace('[', '').replace(']', '')\
.replace('(', '').replace(')', '').replace(', ', '_')
if len(shape_str) == 0:
shape_str = "scalar"
if canonical == 'tensor_t':
Expand Down
17 changes: 1 addition & 16 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tvm import relay
from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.prelude import Prelude, StaticTensorArrayOps
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr

import numpy as np
Expand Down Expand Up @@ -984,7 +984,6 @@ def test_static_tensor_take():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1006,7 +1005,6 @@ def test_static_tensor_concatenate():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1028,7 +1026,6 @@ def run(dtype, shape):
x = relay.var('x')
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1046,7 +1043,6 @@ def run(dtype, shape):
x = relay.var('x')
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()
tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape)
Expand All @@ -1057,7 +1053,6 @@ def test_static_tensor_array_read():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down Expand Up @@ -1096,7 +1091,6 @@ def test_static_tensor_array_write():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1122,7 +1116,6 @@ def test_static_tensor_array_unstack():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1139,7 +1132,6 @@ def test_static_tensor_array_scatter():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down Expand Up @@ -1191,7 +1183,6 @@ def test_static_tensor_array_split():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down Expand Up @@ -1245,7 +1236,6 @@ def test_static_tensor_array_concat():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1271,7 +1261,6 @@ def test_static_tensor_array_gather():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1298,7 +1287,6 @@ def test_static_tensor_array_stack():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand All @@ -1323,7 +1311,6 @@ def test_static_tensor_get_data_same_shape():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down Expand Up @@ -1363,7 +1350,6 @@ def test_static_tensor_get_data_expand_shape():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down Expand Up @@ -1391,7 +1377,6 @@ def test_static_tensor_get_data_replace_shape():
def run(dtype, shape):
mod = tvm.IRModule()
p = Prelude(mod)
from tvm.relay.prelude import StaticTensorArrayOps
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()

Expand Down

0 comments on commit 92d7eaf

Please sign in to comment.