Skip to content

Commit

Permalink
Unify get_static_name
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Apr 2, 2020
1 parent 0e56545 commit 4a9dd1d
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
from . import op


def _get_name_static(canonical, dtype, shape):
"""Get name for static shape tensor array op corresponding
to the canonical name"""
shape_str = '_'.join([str(dim) for dim in shape])
if len(shape_str) == 0:
shape_str = "scalar"
if canonical == 'tensor_t':
return 'static_tensor_{}_{}_t'.format(dtype, shape_str)
return "{}_{}_{}".format(canonical, dtype, shape_str)

class StaticTensorArrayOps(object):
"""Contains tensor array related ops for fixed rank tensor array"""

Expand All @@ -38,12 +48,7 @@ def __init__(self, prelude, dtype, shape):

def get_name(self, canonical):
"""Get name corresponding to the canonical name"""
shape_str = '_'.join([str(dim) for dim in self.shape])
if len(shape_str) == 0:
shape_str = "scalar"
if canonical == 'tensor_t':
return 'static_tensor_{}_{}_t'.format(self.dtype, shape_str)
return "{}_{}_{}".format(canonical, self.dtype, shape_str)
return _get_name_static(canonical, self.dtype, self.shape)

def get_var(self, canonical):
"""Get var corresponding to the canonical name"""
Expand Down Expand Up @@ -218,7 +223,7 @@ def define_tensor_array_unstack(self):
tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
"""
ndim = len(self.shape)
# We don't register unstask for scalar tensor array
# We don't register unstack for scalar tensor array
if ndim == 0:
return

Expand Down Expand Up @@ -1202,12 +1207,7 @@ def get_var(self, canonical, dtype):

def get_name_static(self, canonical, dtype, shape):
"""Get name corresponding to the canonical name"""
if canonical == 'tensor_t':
return 'static_tensor_{}_{}_t'.format(dtype, shape)
shape_str = '_'.join([str(dim) for dim in shape])
if len(shape_str) == 0:
shape_str = "scalar"
return "{}_{}_{}".format(canonical, dtype, shape_str)
return _get_name_static(canonical, dtype, shape)

def get_var_static(self, canonical, dtype, shape):
"""Get var corresponding to the canonical name"""
Expand Down

0 comments on commit 4a9dd1d

Please sign in to comment.