From fdb37e23dd1b66fd6f94c1c2e6a7e1b244655afe Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Tue, 31 Mar 2020 18:04:14 +0000
Subject: [PATCH] Improve shape parsing

---
 python/tvm/relay/prelude.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index a122408899d4..24eed31d9c5c 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -38,9 +38,7 @@ def __init__(self, prelude, dtype, shape):
 
     def get_name(self, canonical):
         """Get name corresponding to the canonical name"""
-        shape_str = str(self.shape).replace('[', '').replace(']', '')\
-            .replace('(', '').replace(')', '').replace(', ', '_')\
-            .replace(',', '')
+        shape_str = '_'.join([str(dim) for dim in self.shape])
         if len(shape_str) == 0:
             shape_str = "scalar"
         if canonical == 'tensor_t':
@@ -53,8 +51,8 @@ def get_var(self, canonical):
         return getattr(self.prelude, name)
 
     def define_tensor_adt(self):
-        """Defines the dynamic tensor ADT, which is the container for tensors
-        with variable shapes."""
+        """Defines the static tensor ADT, which is the container for tensors
+        with fixed shapes."""
         tensor_type_name = self.get_name('tensor_t')
         # Skip register if tensor type is already registered.
         global_type_names = set()
@@ -100,6 +98,7 @@ def define_tensor_take(self):
         tensor_take(t, lower, upper) :
             tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
         """
+        # We don't register take for scalar tensors.
         ndim = len(self.shape)
         if ndim == 0:
             return
@@ -107,6 +106,7 @@ def define_tensor_take(self):
         take_name = self.get_name("tensor_take")
         take_var = self._create_global_var(take_name)
         setattr(self.prelude, take_name, take_var)
+        origin_tensor_constructor = self.get_var('tensor_constructor')
 
         output_shape = [Any(),] + list(self.shape[1:])
         tensor_type_var, tensor_constructor = \
@@ -116,7 +116,7 @@ def define_tensor_take(self):
         lower = Var('lower', scalar_type('int32'))
         upper = Var('upper', scalar_type('int32'))
         tvar = Var('t')
-        case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]),
+        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
                       tensor_constructor(op.take(tvar,
                                                  op.arange(lower, upper, dtype='int32'),
                                                  axis=0)))
@@ -128,6 +128,7 @@ def define_tensor_concatenate(self):
         """Defines a function to concatenate two tensor_t on axis 0.
         tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
         """
+        # We don't register concatenate for scalar tensors.
         ndim = len(self.shape)
         if ndim == 0:
             return
@@ -180,7 +181,7 @@ def define_tensor_expand_dims(self):
             Function([x], Match(x, [case], False), tensor_type_var(), [])
 
     def define_tensor_array_read(self):
-        """Defines a function to get the head of a list. Assume the list has at least one
+        """Defines a function to get the nth element of a list. Assume the list has at least one
         element.
         tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
             Tensor[self.shape, self.dtype]
@@ -217,7 +218,7 @@ def define_tensor_array_unstack(self):
         tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
         """
         ndim = len(self.shape)
-        # Skip scalar case
+        # We don't register unstack for scalar tensor arrays.
         if ndim == 0:
             return
@@ -254,7 +255,7 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
             list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
 
         Set static indices shape by specifying indices_shape.
-        Set for_update to get static indices shape operator.
+        Set force_update to get a static indices shape operator.
         """
         # When this operator has already been registered, only update
         # when force_update is set. This should be used only when we need to
@@ -310,7 +311,8 @@ def define_tensor_array_split(self,
         tensor_array_split(ta, value, lengths) :
             list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
 
-        Set static value and lengths shapes by specifying lengths_shape.
+        Set static value and lengths shapes by specifying value_shape and lengths_shape.
+        Set force_update to get a static value and lengths shape operator.
         """
         # Skip scalar case
         ndim = len(self.shape)
         if ndim == 0:
             return
@@ -403,7 +405,7 @@ def define_tensor_array_concat(self):
         """Defines a function to return the values in the tensor array as concatenated tensor_t.
         tensor_array_concat(ta) : list[tensor_t] -> tensor_t
         """
-        # Skip scalar case
+        # We don't register concat for scalar tensor arrays.
         ndim = len(self.shape)
         if ndim == 0:
             return
@@ -548,7 +550,7 @@ def _get_adt_by_shape(self, shape):
         return tensor_type_var, tensor_constructor
 
     def _create_global_var(self, name):
-        """Create a GlobalVar if not show in prelude."""
+        """Create a GlobalVar if it doesn't already exist in the prelude."""
         global_var_name_set = set()
         for g_var_name in self.prelude.mod.get_global_vars():
             global_var_name_set.add(g_var_name.name_hint)
@@ -1202,7 +1204,7 @@ def get_name_static(self, canonical, dtype, shape):
         """Get name corresponding to the canonical name"""
         if canonical == 'tensor_t':
             return 'static_tensor_{}_{}_t'.format(dtype, shape)
-        shape_str = str(shape).replace('[', '').replace(']', '').replace(', ', '_')
+        shape_str = '_'.join([str(dim) for dim in shape])
         if len(shape_str) == 0:
             shape_str = "scalar"
         return "{}_{}_{}".format(canonical, dtype, shape_str)
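
Note for reviewers (not part of the patch): below is a minimal standalone sketch of the name-mangling change, with hypothetical helper names mirroring the removed and added lines. The removed get_name_static() variant only stripped list syntax, so a tuple shape leaked parentheses into the generated name; the join-based form gives lists and tuples the same suffix.

def old_get_name_suffix(shape):
    # Removed form from get_name(): stringify the whole container, then
    # strip brackets, parentheses, and commas one replace() at a time.
    return str(shape).replace('[', '').replace(']', '') \
                     .replace('(', '').replace(')', '') \
                     .replace(', ', '_').replace(',', '')

def old_get_name_static_suffix(shape):
    # Removed form from get_name_static(): only list syntax is stripped,
    # so tuple shapes leak '(' and ')' into the generated name.
    return str(shape).replace('[', '').replace(']', '').replace(', ', '_')

def new_shape_suffix(shape):
    # Added form, shared by both call sites: join per-dimension strings.
    return '_'.join([str(dim) for dim in shape])

print(old_get_name_static_suffix((2, 3)))  # (2_3)  -- broken name
print(new_shape_suffix((2, 3)))            # 2_3
print(old_get_name_suffix([2, 3]))         # 2_3    -- lists behave as before
print(new_shape_suffix([]))                # ''     -> callers fall back to "scalar"

The empty-shape case still yields an empty string, which is presumably why both call sites keep the "scalar" fallback.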