Commit

Improve shape parsing
kevinthesun committed Apr 3, 2020
1 parent d22c92e commit fdb37e2
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions python/tvm/relay/prelude.py
@@ -38,9 +38,7 @@ def __init__(self, prelude, dtype, shape):

def get_name(self, canonical):
"""Get name corresponding to the canonical name"""
shape_str = str(self.shape).replace('[', '').replace(']', '')\
.replace('(', '').replace(')', '').replace(', ', '_')\
.replace(',', '')
shape_str = '_'.join([str(dim) for dim in self.shape])
if len(shape_str) == 0:
shape_str = "scalar"
if canonical == 'tensor_t':
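
To see what the new naming does, here is a small standalone Python comparison (no TVM needed) of the old string-munging approach and the new join-based one for a plain integer shape; both yield the same suffix, but the join form works uniformly for lists and tuples without stripping delimiters:

shape = (1, 16, 16)

# Old approach: stringify the whole shape, then strip brackets/parens and
# rewrite the separators.
old = str(shape).replace('[', '').replace(']', '') \
                .replace('(', '').replace(')', '').replace(', ', '_') \
                .replace(',', '')

# New approach: format each dimension and join with underscores.
new = '_'.join([str(dim) for dim in shape])

print(old, new)  # 1_16_16 1_16_16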
@@ -53,8 +51,8 @@ def get_var(self, canonical):
return getattr(self.prelude, name)

def define_tensor_adt(self):
"""Defines the dynamic tensor ADT, which is the container for tensors
with variable shapes."""
"""Defines the static tensor ADT, which is the container for tensors
with fixed shapes."""
tensor_type_name = self.get_name('tensor_t')
# Skip register if tensor type is already registered.
global_type_names = set()
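
The registration body is folded out of this diff; as a rough, hedged sketch of what defining such a static tensor ADT can look like with the Relay ADT API (the concrete names below are illustrative, not the exact ones the prelude generates):

import tvm
from tvm.relay import GlobalTypeVar, Constructor, TypeData, TensorType

mod = tvm.IRModule()

# A static tensor ADT for float32 tensors of shape (1, 16): a single
# constructor that wraps a fixed-shape TensorType.
tensor_type_var = GlobalTypeVar('static_tensor_float32_1_16_t')
tensor_constructor = Constructor('tensor_constructor_float32_1_16',
                                 [TensorType((1, 16), 'float32')],
                                 tensor_type_var)

# Assigning a TypeData registers the ADT with the module.
mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_constructor])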
@@ -100,13 +98,15 @@ def define_tensor_take(self):
tensor_take(t, lower, upper) :
tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
"""
# We don't register take for scalar tensor.
ndim = len(self.shape)
if ndim == 0:
return

take_name = self.get_name("tensor_take")
take_var = self._create_global_var(take_name)
setattr(self.prelude, take_name, take_var)
origin_tensor_constructor = self.get_var('tensor_constructor')

output_shape = [Any(),] + list(self.shape[1:])
tensor_type_var, tensor_constructor = \
@@ -116,7 +116,7 @@ def define_tensor_take(self):
lower = Var('lower', scalar_type('int32'))
upper = Var('upper', scalar_type('int32'))
tvar = Var('t')
case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]),
case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
tensor_constructor(op.take(tvar,
op.arange(lower, upper, dtype='int32'),
axis=0)))
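
The clause above builds the sliced tensor with op.take over a runtime arange, which is why the leading output dimension is Any(); as a hedged NumPy analogy of the same computation:

import numpy as np

t = np.arange(12).reshape(4, 3)
lower, upper = 1, 3

# take(t, arange(lower, upper), axis=0) selects rows [lower, upper) along
# axis 0 -- the same values as t[lower:upper].
taken = np.take(t, np.arange(lower, upper), axis=0)
assert (taken == t[lower:upper]).all()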
@@ -128,6 +128,7 @@ def define_tensor_concatenate(self):
"""Defines a function to concatenate two tensor_t on axis 0.
tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
"""
# We don't register concatenate for scalar tensor.
ndim = len(self.shape)
if ndim == 0:
return
@@ -180,7 +181,7 @@ def define_tensor_expand_dims(self):
Function([x], Match(x, [case], False), tensor_type_var(), [])

def define_tensor_array_read(self):
"""Defines a function to get the head of a list. Assume the list has at least one
"""Defines a function to get the nth element of a list. Assume the list has at least one
element.
tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
Tensor[self.shape, self.dtype]
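
In plain-list terms the corrected docstring amounts to positional indexing; a rough Python analogy (illustrative only, the Relay version walks the prelude's cons-list):

def tensor_array_read(ta, n):
    # ta is a list of tensors with at least n + 1 elements; return the nth one.
    return ta[n]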
@@ -217,7 +218,7 @@ def define_tensor_array_unstack(self):
tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
"""
ndim = len(self.shape)
# Skip scalar case
# We don't register unstack for scalar tensor array.
if ndim == 0:
return

@@ -254,7 +255,7 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
Set static indices shape by specifying indices_shape.
Set for_update to get static indices shape operator.
Set force_update to get static indices shape operator.
"""
# When this operator has already been registered, only update
# when force_update is set. This should be used only when we need to
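
A hedged sketch of the guard that comment describes (the attribute check and names are assumptions; the actual folded code may differ):

tensor_array_scatter_name = self.get_name("tensor_array_scatter")
if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
    # Already registered (e.g. with a dynamic indices shape); keep it unless
    # the caller explicitly asks to re-register with static shapes.
    return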
@@ -310,7 +311,8 @@ def define_tensor_array_split(self,
tensor_array_split(ta, value, lengths) :
list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
Set static value and lengths shapes by specifying lengths_shape.
Set static value and lengths shapes by specifying value_shape and lengths_shape.
Set force_update to get static value and lengths shape operator.
"""
# Skip scalar case
ndim = len(self.shape)
@@ -403,7 +405,7 @@ def define_tensor_array_concat(self):
"""Defines a function to return the values in the tensor array as concatenated tensor_t.
tensor_array_concat(ta) : list[tensor_t] -> tensor_t
"""
# Skip scalar case
# We don't register concat for scalar tensor array.
ndim = len(self.shape)
if ndim == 0:
return
@@ -548,7 +550,7 @@ def _get_adt_by_shape(self, shape):
return tensor_type_var, tensor_constructor

def _create_global_var(self, name):
"""Create a GlobalVar if not show in prelude."""
"""Create a GlobalVar if doesn't exist in prelude."""
global_var_name_set = set()
for g_var_name in self.prelude.mod.get_global_vars():
global_var_name_set.add(g_var_name.name_hint)
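
The tail of the method is folded out of the diff; as a self-contained sketch of the same lookup pattern (the helper name is hypothetical, and it assumes relay.GlobalVar plus IRModule.get_global_vars/get_global_var):

from tvm.relay import GlobalVar

def create_global_var(mod, name):
    """Return the GlobalVar registered under `name` in `mod`, or a fresh,
    unregistered GlobalVar if no such name exists yet."""
    existing = {gv.name_hint for gv in mod.get_global_vars()}
    return mod.get_global_var(name) if name in existing else GlobalVar(name)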
@@ -1202,7 +1204,7 @@ def get_name_static(self, canonical, dtype, shape):
"""Get name corresponding to the canonical name"""
if canonical == 'tensor_t':
return 'static_tensor_{}_{}_t'.format(dtype, shape)
shape_str = str(shape).replace('[', '').replace(']', '').replace(', ', '_')
shape_str = '_'.join([str(dim) for dim in shape])
if len(shape_str) == 0:
shape_str = "scalar"
return "{}_{}_{}".format(canonical, dtype, shape_str)
