From 7083ef8fba0f58dd8c88ed777670498b09d3c4d7 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 14 Aug 2019 17:49:40 -0700 Subject: [PATCH 01/12] [Relay][Frontend][TF] Add tensor array ops --- python/tvm/relay/frontend/tensorflow.py | 91 +++- python/tvm/relay/op/_tensor.py | 26 + python/tvm/relay/prelude.py | 512 ++++++++++++++++++ python/tvm/relay/testing/py_converter.py | 8 +- src/relay/backend/vm/serializer.cc | 439 +++++++++++++++ .../frontend/tensorflow/test_forward.py | 118 +++- tests/python/relay/test_adt.py | 174 +++++- tests/python/relay/test_feature.py | 3 +- 8 files changed, 1361 insertions(+), 10 deletions(-) create mode 100644 src/relay/backend/vm/serializer.cc diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 38f9c523e0b1..9a7b471e9095 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -22,10 +22,14 @@ import warnings from collections import defaultdict + # Numpy support import numpy as np import tvm + +from tvm.relay.prelude import Prelude + from .. import analysis from .. import expr as _expr from .. import op as _op @@ -508,6 +512,78 @@ def _impl(inputs, attr, params): return _op.concatenate(inputs_reshaped, axis) return _impl +def _tensor_array(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('dtype').name + tensor_array_constructor = getattr(prelude, "tensor_array_{}".format(dtype_str)) + return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) + return _impl + +def _tensor_array_scatter(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('T').name + values_rank = len(inputs[2].type_annotation.shape) + unstack_function_name = "tensor_array_unstack_tensor{}_{}".format(values_rank, dtype_str) + + values = getattr(prelude, unstack_function_name)(inputs[2]) + + tensor_array_scatter_name = "tensor_array_scatter_{}".format(dtype_str) + tensor_array_scatter_var = getattr(prelude, tensor_array_scatter_name) + return tensor_array_scatter_var(inputs[0], inputs[1], values) + return _impl + +def _tensor_array_gather(): + def _impl(inputs, attr, params, prelude): + return prelude.tensor_array_gather(inputs[2], inputs[1]) + return _impl + +def _tensor_array_size(): + def _impl(inputs, attr, params, prelude): + return prelude.length(inputs[0]) + return _impl + +def _tensor_array_write(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[2].type_annotation.shape) + dtype = attr.get('T').name + + tensor_name = 'tensor{}_{}'.format(input_rank, dtype) + tensor_func = getattr(prelude, tensor_name) + v = tensor_func(inputs[2]) + + write_name = 'tensor_array_write_{}'.format(dtype) + write_func = getattr(prelude, write_name) + + return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) + return _impl + +def _tensor_array_read(): + def _impl(inputs, attr, params, prelude): + read_name = 'tensor_array_read_{}'.format(attr.get('dtype').name) + read_func = getattr(prelude, read_name) + return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + return _impl + +def _tensor_array_split(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[1].type_annotation.shape) + dtype_str = attr.get('T').name + tensor_constructor_name = "tensor{}_{}".format(input_rank, dtype_str) + v = getattr(prelude, tensor_constructor_name)(inputs[1]) + lengths = _op.cast(inputs[2], 'int32') + + split_name = "tensor_array_split_{}".format(dtype_str) + split_var = getattr(prelude, split_name) + return 
split_var(inputs[0], v, lengths) + return _impl + +def _tensor_array_concat(): + def _impl(inputs, attr, params, prelude): + concat_name = 'tensor_array_concat_{}'.format(attr['dtype'].name) + concat_func = getattr(prelude, concat_name) + return concat_func(inputs[1]) + return _impl + def _tile(): def _impl(inputs, attr, params): reps = _get_list_param(params, inputs.pop()) @@ -1313,6 +1389,14 @@ def _impl(inputs, attr, params): 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), + 'TensorArrayV3' : _tensor_array(), + 'TensorArrayScatterV3' : _tensor_array_scatter(), + 'TensorArrayGatherV3' : _tensor_array_gather(), + 'TensorArraySizeV3' : _tensor_array_size(), + 'TensorArrayWriteV3' : _tensor_array_write(), + 'TensorArrayReadV3' : _tensor_array_read(), + 'TensorArraySplitV3' : _tensor_array_split(), + 'TensorArrayConcatV3' : _tensor_array_concat(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), @@ -1860,6 +1944,7 @@ def __init__(self): self._loops = {} self._branches = {} self._mod = _module.Module({}) + self._prelude = Prelude(self._mod) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -2335,7 +2420,11 @@ def _convert_operator(self, op_name, inputs, attrs, if op_name in identity_list: sym = get_relay_op(op_name)(*inputs, **attrs) elif op_name in convert_map: - sym = convert_map[op_name](inputs, attrs, self._params) + if 'TensorArray' in op_name: + sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = convert_map[op_name](inputs, attrs, self._params) + elif op_name in convert_map_rnn: sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da5804906269..188b3bb15956 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target): register_schedule("clip", schedule_elemwise) +@script +def _cast_shape_function(x): + out_ndim = len(x) + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = x[i] + return out + +def cast_shape_func(attrs, inputs, out_ndims): + return [_cast_shape_function(*inputs)] + +@script +def _expand_dims_shape_func(x): + ndim = len(x.shape) + out = output_tensor((ndim+1,), "int64") + out[0] = int64(1) + for i in const_range(0, ndim): + out[i+1] = int64(x.shape[i]) + return out + +def expand_dims_shape_func(attrs, inputs, out_ndims): + return [_expand_dims_shape_func(*inputs)] + # shape func @script def _broadcast_shape_func(x, y, ndim): @@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim): def broadcast_shape_func(attrs, inputs, out_ndims): return [_broadcast_shape_func(*inputs, out_ndims[0])] +register_shape_func("expand_dims", False, expand_dims_shape_func) +register_shape_func("cast", False, cast_shape_func) + register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("multiply", False, broadcast_shape_func) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 803d8ef50db5..6cdc5981ef9b 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,8 +16,516 @@ # under the License. 
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" +from .ty import GlobalTypeVar, TensorType, Any, scalar_type +from .expr import Var, Function, GlobalVar, If, const +from .op.tensor import add, subtract, equal +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard +from . import op from .module import Module +class TensorArrayOps: + """Contains tensor array related ops""" + + def __init__(self, prelude, dtype): + """Create tensor array ops registry""" + self.prelude = prelude + self.dtype = dtype + + def get_name(self, canonical): + """Get name corresponding to the caninical name""" + if canonical == 'tensor_t': + return 'tensor_{}_t'.format(self.dtype) + return "{}_{}".format(canonical, self.dtype) + + def get_var(self, canonical): + """Get var corresponding to the caninical name""" + name = self.get_name(canonical) + return getattr(self.prelude, name) + + def define_tensor_adt(self): + """Defines the dynamic tensor ADT, which is the container for tensors + with variable shapes.""" + tensor_type_name = self.get_name('tensor_t') + tensor_type_var = GlobalTypeVar(tensor_type_name) + setattr(self.prelude, tensor_type_name, tensor_type_var) + tensor0_type = TensorType([], self.dtype) + tensor1_type = TensorType([Any()], self.dtype) + tensor2_type = TensorType([Any(), Any()], self.dtype) + tensor3_type = TensorType([Any(), Any(), Any()], self.dtype) + tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype) + tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor_nil_name = self.get_name('tensor_nil') + tensor0_name = self.get_name('tensor0') + tensor1_name = self.get_name('tensor1') + tensor2_name = self.get_name('tensor2') + tensor3_name = self.get_name('tensor3') + tensor4_name = self.get_name('tensor4') + tensor5_name = self.get_name('tensor5') + tensor6_name = self.get_name('tensor6') + tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) + tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var) + tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var) + tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var) + tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var) + tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) + tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) + tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) + setattr(self.prelude, tensor_nil_name, tensor_nil_case) + setattr(self.prelude, tensor0_name, tensor0_case) + setattr(self.prelude, tensor1_name, tensor1_case) + setattr(self.prelude, tensor2_name, tensor2_case) + setattr(self.prelude, tensor3_name, tensor3_case) + setattr(self.prelude, tensor4_name, tensor4_case) + setattr(self.prelude, tensor5_name, tensor5_case) + setattr(self.prelude, tensor6_name, tensor6_case) + self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case, + tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case]) + + def define_tensor_take(self): + """Defines a function to return a range of tensor_t on axis 0. 
+ tensor_take(t, lower, upper) : + tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t + """ + take_name = self.get_name("tensor_take") + take_var = GlobalVar(take_name) + setattr(self.prelude, take_name, take_var) + tensor_t = self.get_var('tensor_t') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + t = Var('tensor', tensor_t()) + lower = Var('lower', scalar_type('int32')) + upper = Var('upper', scalar_type('int32')) + t1 = Var('t1') + t2 = Var('t2') + t3 = Var('t3') + t4 = Var('t4') + t5 = Var('t5') + t6 = Var('t6') + tensor1_case =\ + Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32')))) + tensor2_case =\ + Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor3_case =\ + Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor4_case =\ + Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor5_case =\ + Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor6_case =\ + Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]), + tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0))) + self.prelude.mod[take_var] =\ + Function([t, lower, upper], + Match(t, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case], False), + tensor_t(), []) + + def define_tensor_add_one(self): + """Defines a function to grow a tensor_t's rank by adding one dimention in front + of the original tensor_t. 
+ tensor_add_one(t) : tensor_t -> tensor_t + """ + add_one_name = self.get_name("tensor_add_one") + add_one_var = GlobalVar(add_one_name) + setattr(self.prelude, add_one_name, add_one_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + t0 = Var("t0") + t1 = Var("t1") + t2 = Var("t2") + t3 = Var("t3") + t4 = Var("t4") + t5 = Var("t5") + tensor0_var = self.get_var('tensor0') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]), + tensor1_var(op.expand_dims(t0, 0, 1))) + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor2_var(op.expand_dims(t1, 0, 1))) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor3_var(op.expand_dims(t2, 0, 1))) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor4_var(op.expand_dims(t3, 0, 1))) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor5_var(op.expand_dims(t4, 0, 1))) + tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor6_var(op.expand_dims(t5, 0, 1))) + self.prelude.mod[add_one_var] =\ + Function([x], + Match(x, [tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case], False)) + + def define_tensor_concat(self): + """Defines a function to concatenate two tensor_t on the first axis + + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + concat_name = self.get_name("tensor_concatenate") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + y = Var("y", tensor_type_var()) + + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + t11 = Var("t11") + t12 = Var("t12") + t21 = Var("t21") + t22 = Var("t22") + t31 = Var("t31") + t32 = Var("t32") + t41 = Var("t41") + t42 = Var("t42") + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]), + Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]), + tensor1_var(op.concatenate([t11, t12], axis=0)))], + False)) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]), + Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]), + tensor2_var(op.concatenate([t21, t22], axis=0)))], + False)) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]), + Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]), + tensor3_var(op.concatenate([t31, t32], axis=0)))], + False)) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]), + Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]), + tensor4_var(op.concatenate([t41, t42], axis=0)))], + False)) + # op.concatenate does not support tensor with rank higher than 4 + self.prelude.mod[concat_var] =\ + Function([x, y], Match(x, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case], False)) + + def define_tensor_array(self): + """Defines a function to create a tensor array with size n. 
+ tensor_array(n) : Tensor[(), int32] -> list[tensor_t] + """ + tensor_array_constructor_name = self.get_name("tensor_array") + tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) + setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) + tensor_nil_var = self.get_var('tensor_nil') + tensor_type_var = self.get_var('tensor_t') + n = Var("x", scalar_type('int32')) + body = If(equal(n, const(0)), + self.prelude.nil(), + self.prelude.cons(tensor_nil_var(), + tensor_array_constructor_var(subtract(n, const(1))))) + self.prelude.mod[tensor_array_constructor_var] = \ + Function([n], body, self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_read(self): + """Defines a function to get the head of a list. Assume the list has at least one + element. + + tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t + """ + read_name = self.get_name("tensor_array_read") + read_var = GlobalVar(read_name) + setattr(self.prelude, read_name, read_var) + tensor_type_var = self.get_var('tensor_t') + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + self.prelude.mod[read_var] =\ + Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []) + + def define_tensor_array_write(self): + """Defines a function to update a tensor array at index n with value v. + tensor_array_write(ta, n, v) : + list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t] + """ + write_name = self.get_name("tensor_array_write") + write_var = GlobalVar(write_name) + setattr(self.prelude, write_name, write_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + v = Var("v", tensor_type_var()) + self.prelude.mod[write_var] =\ + Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor1(self): + """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array. + tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor1_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + tensor_type_var = self.get_var('tensor_t') + tensor0_var = self.get_var('tensor0') + helper_body =\ + If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(tensor0_var(op.take(tensor, i)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []) + unstack_name = self.get_name("tensor_array_unstack_tensor1") + unstack_var = GlobalVar(unstack_name) + setattr(self.prelude, unstack_name, unstack_var) + tensor1 = Var("tensor", TensorType([Any()], self.dtype)) + shape = op.shape_of(tensor1) + ndim = op.take(shape, const(0)) + self.prelude.mod[unstack_var] =\ + Function([tensor1], helper_var(const(0), ndim, tensor1), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor2(self): + """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array. 
+ + tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor2_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") + tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name) + setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var) + tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype)) + shape = op.shape_of(tensor2) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor2_var] =\ + Function([tensor2], helper_var(const(0), ndim, tensor2), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_scatter(self): + """Defines a function to scatter the values of a tensor_t in indices of a tensor array.. + tensor_array_scatter(ta, indices, value) : + list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] + """ + tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) + tensor_t = self.get_var('tensor_t') + ta = Var("ta", self.prelude.l(tensor_t())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + values_ = Var('values_', self.prelude.l(tensor_t())) + write_var = self.get_var('tensor_array_write') + read_var = self.get_var('tensor_array_read') + helper_body = If(equal(current, limit), + ta, + tensor_array_scatter_helper_var( + write_var(ta, op.take(indices_, current), + read_var(values_, current)), + add(current, const(1)), + limit, indices_, values_)) + self.prelude.mod[tensor_array_scatter_helper_var] =\ + Function([ta, current, limit, indices_, values_], + helper_body, self.prelude.l(tensor_t()), []) + tensor_array_scatter_name = self.get_name("tensor_array_scatter") + tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) + setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + indices = Var('indices', TensorType([Any()], 'int32')) + values = Var('values', self.prelude.l(tensor_t())) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) + self.prelude.mod[tensor_array_scatter_var] =\ + Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_split(self): + """Defines a function to split the values of a tensor_t into a tensor array. 
+ tensor_array_split(ta, value, lengths) : + list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] + """ + tensor_t = self.get_var('tensor_t') + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) + setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + ta1 = Var("tensor_array", self.prelude.l(tensor_t())) + value1 = Var('value1', tensor_t()) + offset1 = Var('offset1', scalar_type('int32')) + current1 = Var('current1', scalar_type('int32')) + limit1 = Var('limit1', scalar_type('int32')) + lengths1 = Var('lengths', TensorType([Any()], 'int32')) + write_var = self.get_var('tensor_array_write') + take_var = self.get_var('tensor_take') + helper1_body = If(equal(current1, limit1), + ta1, + write_var( + tensor_array_split_helper_var( + ta1, + value1, + add(offset1, op.take(lengths1, current1)), + add(current1, const(1)), + limit1, + lengths1 + ), + current1, + take_var(value1, + offset1, + add(op.take(lengths1, current1), offset1)))) + self.prelude.mod[tensor_array_split_helper_var] = \ + Function([ta1, value1, offset1, current1, limit1, lengths1], + helper1_body, self.prelude.l(tensor_t()), []) + split_name = self.get_name("tensor_array_split") + split_var = GlobalVar(split_name) + setattr(self.prelude, split_name, split_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + value = Var('value', tensor_t()) + lengths = Var('lengths', TensorType([Any()], 'int32')) + lengths_shape = op.shape_of(lengths) + lengths_limit = op.take(lengths_shape, const(0)) + body = tensor_array_split_helper_var( + tensor_array, + value, + const(0), + const(0), + lengths_limit, + lengths) + self.prelude.mod[split_var] =\ + Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_concat(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + concat_name = self.get_name("tensor_array_concat") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_concat_var = self.get_var('tensor_concatenate') + tensor_t = self.get_var('tensor_t') + tensor_nil_var = self.get_var('tensor_nil') + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] =\ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []) + + def define_tensor_array_gather(self): + """Defines a function to return the selected values in a tensor array as tensor_t. 
+ tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t + """ + helper_name = self.get_name("tensor_array_gather_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor_type_var = self.get_var('tensor_t') + stack_var = self.get_var('tensor_array_stack') + read_var = self.get_var('tensor_array_read') + ta = Var("ta", self.prelude.l(tensor_type_var())) + accu = Var("accu", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + helper_body =\ + If(equal(current, const(0)), + stack_var(accu), + helper_var( + ta, + self.prelude.cons( + read_var( + ta, op.take(indices_, subtract(current, const(1)))), accu), + subtract(current, const(1)), + limit, indices_)) + self.prelude.mod[helper_var] = \ + Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []) + gather_name = self.get_name("tensor_array_gather") + gather_var = GlobalVar(gather_name) + setattr(self.prelude, gather_name, gather_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + indices = Var('indices', TensorType([Any()], 'int32')) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + self.prelude.mod[gather_var] =\ + Function([tensor_array, indices], body, tensor_type_var(), []) + + def define_tensor_array_stack(self): + """Defines a function to get the values in the tensor array as a stack tensor_t. + tensor_array_stack(l) : list[tensor_t] -> tensor_t + """ + stack_name = self.get_name("tensor_array_stack") + stack_var = GlobalVar(stack_name) + setattr(self.prelude, stack_name, stack_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + add_one_var = self.get_var('tensor_add_one') + concat_var = self.get_var('tensor_concatenate') + tensor_array_add_one = self.prelude.map(add_one_var, tensor_array) + tensors = self.prelude.foldl(concat_var, + self.prelude.hd(tensor_array_add_one), + self.prelude.tl(tensor_array_add_one)) + self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) + + def register(self): + """Register all tensor array ops in Prelude""" + self.define_tensor_adt() + self.define_tensor_take() + self.define_tensor_add_one() + self.define_tensor_concat() + self.define_tensor_array() + self.define_tensor_array_read() + self.define_tensor_array_write() + self.define_tensor_array_unstack_tensor1() + self.define_tensor_array_unstack_tensor2() + self.define_tensor_array_scatter() + self.define_tensor_array_split() + self.define_tensor_array_concat() + self.define_tensor_array_stack() + # TODO(wweic): Gather fails in PartialEvaluate + # self.define_tensor_array_gather() + class Prelude: """Contains standard definitions.""" @@ -74,3 +582,7 @@ def load_prelude(self): ] for global_def in GLOBAL_DEFS: setattr(self, global_def, self.mod.get_global_var(global_def)) + + for dtype in ['float32', 'int32']: + tensor_array_ops = TensorArrayOps(self, dtype) + tensor_array_ops.register() diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d661be73ad02..d7b59922b89d 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -203,8 +203,12 @@ def convert_module(self): for var, func in 
self.mod.functions.items(): # optimize the definition so any operators used are lowered opt_func = self.optimize(func) - converted_func, _ = self.convert_func_node(opt_func, var) - defs.append(converted_func) + try: + converted_func, _ = self.convert_func_node(opt_func, var) + defs.append(converted_func) + except TypeError: + # TODO(wweic): fix conversion for Any + pass return defs diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc new file mode 100644 index 000000000000..a3d882818a89 --- /dev/null +++ b/src/relay/backend/vm/serializer.cc @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/serializer.cc + * \brief Implementation of serializing APIs for the Relay VM. + */ +#include "serializer.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace relay { +namespace vm { + +void Serializer::Init(const VirtualMachine* vm) { + vm_ = vm; + // Initialize the stream object. 
+ strm_ = new dmlc::MemoryStringStream(&code_); +} + +runtime::PackedFunc Serializer::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetLib(); + }); + } else if (name == "get_primitive_ops") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetPrimitiveOps(); + }); + } else if (name == "get_bytecode") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetBytecode(); + }); + } else if (name == "get_globals") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetGlobals(); + }); + } else if (name == "get_stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Stats(); + }); + } else if (name == "serialize") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Serialize(); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } +} + +tvm::Array Serializer::GetPrimitiveOps() const { + std::vector ret; + for (const auto& it : vm_->primitive_map) { + auto packed_name = tvm::ir::StringImm::make(it.first); + auto packed_index = static_cast(it.second); + if (ret.size() <= packed_index) { + ret.resize(packed_index + 1); + } + ret[packed_index] = packed_name; + } + return ret; +} + +std::string Serializer::Stats() const { + std::ostringstream oss; + oss << "Relay VM statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << vm_->constants.size() << "): ["; + for (const auto& it : vm_->constants) { + auto* cell = it.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + const auto& shape = data.Shape(); + + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << vm_->global_map.size() << "): ["; + for (const auto& it : vm_->global_map) { + oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + } + if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of primitive ops and the name of each of them. + oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; + const auto& prim_ops = GetPrimitiveOps(); + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +TVMByteArray Serializer::Serialize() { + uint64_t header = kTVMVMBytecodeMagic; + strm_->Write(header); + std::string version = TVM_VERSION; + strm_->Write(version); + + // Global section. + SerializeGlobalSection(); + + // Constant section. + SerializeConstantSection(); + + // Primitive names. + SerializePrimitiveOpNames(); + + // Code section. 
+ SerializeCodeSection(); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Serializer::SerializeGlobalSection() { + auto globals = GetGlobals(); + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.as()->value); + } + strm_->Write(glbs); +} + +void Serializer::SerializeConstantSection() { + std::vector arrays; + for (const auto& obj : vm_->constants) { + const auto* cell = obj.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm_->Write(static_cast(vm_->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm_, it); + } +} + +void Serializer::SerializePrimitiveOpNames() { + auto names = GetPrimitiveOps(); + std::vector primitive_names; + for (const auto& it : names) { + primitive_names.push_back(it.as()->value); + } + strm_->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionSerializer`. It is used for sanity check before decoding. +// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. + DLOG(INFO) << "Serializing: " << instr << std::endl; + switch (instr.op) { + case Opcode::Move: { + // Number of fields = 2 + fields.assign({instr.from, instr.dst}); + break; + } + case Opcode::Ret: { + // Number of fields = 1 + fields.push_back(instr.result); + break; + } + case Opcode::Fatal: { + // Number of fields = 0 + break; + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + // Note that arity includes both input arguments and outputs. We will + // put all the `arity` number of fields in the end for serialization. + fields.assign({instr.packed_index, instr.arity, instr.output_size}); + // Save the args. + fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); + break; + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + + // The number of dimensions is not needed for constructing an + // `AllocTensor` instruction as it equals to the length of the `shape` + // vector. However, we save it to conveniently deserialize the instruction + // because we will know how many fields are needed by the `shape` argument. + fields.push_back(instr.alloc_tensor.ndim); + fields.push_back(instr.dst); + + // Save the shape of the tensor. + // Note that this field is rotated to the end of the list. 
+ fields.insert(fields.end(), instr.alloc_tensor.shape, + instr.alloc_tensor.shape + instr.alloc_tensor.ndim); + break; + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + fields.push_back(instr.alloc_tensor_reg.shape_register); + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.insert(fields.end(), {dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(instr.dst); + break; + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); + + // Save the fields. + fields.insert(fields.end(), instr.datatype_fields, + instr.datatype_fields + instr.num_fields); + break; + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); + + // Save the free vars. + fields.insert(fields.end(), instr.free_vars, + instr.free_vars + instr.num_freevar); + break; + } + case Opcode::If: { + // Number of fields = 4 + fields.assign({instr.if_op.test, + instr.if_op.target, + instr.if_op.true_offset, + instr.if_op.false_offset}); + break; + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + fields.assign({instr.func_index, instr.num_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.invoke_args_registers, + instr.invoke_args_registers + instr.num_args); + break; + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + fields.assign({instr.closure, instr.num_closure_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Serializer::SerializeCodeSection() { + // Save the number of functions. + strm_->Write(static_cast(vm_->functions.size())); + for (const auto& func : vm_->functions) { + // Serialize the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm_); + + // Serialize each instruction. 
+ for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm_); + } + } +} + +tvm::Array Serializer::GetGlobals() const { + tvm::Array ret; + std::vector > globals(vm_->global_map.begin(), + vm_->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + for (const auto& it : globals) { + ret.push_back(tvm::ir::StringImm::make(it.first)); + } + return ret; +} + +std::string Serializer::GetBytecode() const { + std::ostringstream oss; + + for (const auto& func : vm_->functions) { + // Print the header of the function format. + oss << "# func name, reg file size, param count, inst count:" + << std::endl; + oss << func.name << " " + << func.register_file_size << " " + << func.params.size() << " " + << func.instructions.size() << std::endl; + + // Print pramams of a `VMFunction`. + oss << "# Parameters:"<< std::endl; + for (const auto& param : func.params) { + oss << param << " "; + } + oss << std::endl; + + // Print the instructions of a `VMFunction`. + // The part after ";" is the instruction in text format. + oss << "hash, opcode, fields # inst(text):"<< std::endl; + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + oss << std::hex << "0x" << serialized_instr.Hash() << " " + << std::dec << serialized_instr.opcode << " "; + for (auto it : serialized_instr.fields) { + oss << it << " "; + } + oss << " # " << instr; + if (oss.str().back() != '\n') oss << std::endl; + } + } + + return oss.str(); +} + +runtime::Module Serializer::GetLib() const { + return vm_->lib; +} + +runtime::Module CreateSerializer(const VirtualMachine* vm) { + std::shared_ptr exec = std::make_shared(); + exec->Init(vm); + return runtime::Module(exec); +} + +TVM_REGISTER_GLOBAL("relay._vm._Serializer") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* vm = dynamic_cast(mod.operator->()); + CHECK(vm) << "Virtual machine has not been defined yet." 
+ << "\n"; + *rv = CreateSerializer(vm); +}); + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c2cbbff24173..3321d71a2cb8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -60,13 +60,19 @@ def vmobj_to_list(o): result.append(vmobj_to_list(f)) return result elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): - if o.constructor.name_hint == 'cons': + if o.constructor.name_hint == 'Cons': tl = vmobj_to_list(o.fields[1]) hd = vmobj_to_list(o.fields[0]) hd.extend(tl) return hd - elif o.constructor.name_hint == 'nil': + elif o.constructor.name_hint == 'Nil': return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.data.asnumpy()] else: @@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) - layout = None if target == "cuda": layout = "NCHW" target_host = None - shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} - mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, @@ -581,6 +584,111 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) +def test_tensor_array_constructor(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + ta2 = ta1.write(0, t) + ta3 = ta2.write(1, t2) + out = ta3.read(0) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') + run('float32') + run('int32') + +def test_tensor_array_scatter(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) + indices = tf.constant([2, 1, 0]) + ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False) + ta2 = ta1.scatter(indices, t) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + run('float32') + run('int32') + +# TODO(wweic): Fix gather issue with PartialEvaluate +# def test_tensor_array_gather(): +# with tf.Graph().as_default(): +# dtype = 'float32' +# t = tf.constant([[1.0], [2.0], [3.0]]) +# scatter_indices = tf.constant([2, 1, 0]) +# gather_indices = tf.constant([1, 2]) +# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False) +# ta2 = ta1.scatter(scatter_indices, t) +# t1 = ta2.gather(gather_indices) +# g = 
tf.get_default_graph() +# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') + +def test_tensor_array_split(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + out3 = ta2.read(3) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_concat(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + t = ta2.concat() + compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_size(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + out = ta1.size() + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') + run('float32') + run('int32') + ####################################################################### # ConcatV2 # -------- diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 7be7c75dfe64..284a1e01b65a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -18,9 +18,11 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.prelude import Prelude +from tvm.relay.prelude import Prelude, TensorArrayOps from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr +import numpy as np + mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) @@ -683,6 +685,170 @@ def test_iterate(): res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 +def test_tensor_array_add_one(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + tensor_array_ops = TensorArrayOps(p, dtype) + add_one_func = tensor_array_ops.get_var('tensor_add_one') + tensor1 = tensor_array_ops.get_var('tensor1') + mod["main"] = relay.Function([x], add_one_func(tensor1(x))) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + x_np = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(x_np) + got = vmobj_to_list(result) + expected = [np.expand_dims(x_np, axis=0)] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_constructor(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + 
tensor_array_ops = TensorArrayOps(p, dtype) + tensor_array = tensor_array_ops.get_var('tensor_array') + mod["main"] = relay.Function([x], tensor_array(x)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(5) + got = vmobj_to_list(result) + expected = np.array([0, 0, 0, 0, 0]) + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_read(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array_ops = TensorArrayOps(p, dtype) + l = relay.var('l') + i = relay.var('i') + read_func = tensor_array_ops.get_var('tensor_array_read') + tensor_array = tensor_array_ops.get_var('tensor_array') + mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(10, 5) + got = vmobj_to_list(result) + expected = [0] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def vmobj_to_list(o): + if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): + return [o.asnumpy()] + elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + +def test_tensor_array_stack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array_ops = TensorArrayOps(p, dtype) + tensor_array = tensor_array_ops.get_var('tensor_array') + tensor1 = tensor_array_ops.get_var('tensor1') + write = tensor_array_ops.get_var('tensor_array_write') + stack = tensor_array_ops.get_var('tensor_array_stack') + l = relay.var('l') + v = relay.var('v') + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + expected = [np.stack([t, t, t])] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') + +def test_tensor_array_unstack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array_ops = TensorArrayOps(p, dtype) + unstack_tensor1 = tensor_array_ops.get_var('tensor_array_unstack_tensor1') + v = relay.var('v') + mod["main"] = relay.Function([v], unstack_tensor1(v)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + tvm.testing.assert_allclose(t, res) + 
run('float32') + run('int32') + +def test_tensor_take(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array_ops = TensorArrayOps(p, dtype) + take = tensor_array_ops.get_var('tensor_take') + tensor2 = tensor_array_ops.get_var('tensor2') + v = relay.var('v') + lower = relay.var('lower') + upper = relay.var('upper') + mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(10, 10)).astype(dtype) + result = ex.evaluate()(t, 2, 5) + res = vmobj_to_list(result) + expected = [np.take(t, range(2, 5), axis=0)] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') + +def test_any_take(): + mod = relay.Module() + p = Prelude(mod) + v = relay.var('v', relay.ty.TensorType([relay.ty.Any(), relay.ty.Any()])) + lower = relay.var('lower', 'int32') + upper = relay.var('upper', 'int32') + + t1 = relay.op.take(v, relay.op.arange(lower, upper, dtype='int32'), axis=0) + + mod["main"] = relay.Function([v, lower, upper], t1) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(10, 10)).astype('float32') + result = ex.evaluate()(t, 2, 5) + res = vmobj_to_list(result) + expected = [np.take(t, np.arange(2, 5), axis=0)] + tvm.testing.assert_allclose(expected, res) if __name__ == "__main__": test_nat_constructor() @@ -707,3 +873,9 @@ def test_iterate(): test_size() test_compose() test_iterate() + + test_tensor_array_add_one() + test_tensor_array_constructor() + test_tensor_array_read() + test_tensor_array_stack() + test_tensor_array_unstack() diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 8f0e90de0315..64eda9d04e7c 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -38,7 +38,8 @@ def test_prelude(): Feature.fLet, Feature.fIf, Feature.fConstructor, - Feature.fMatch + Feature.fMatch, + Feature.fGraph ]) From e62393e683f90a08584c06cf62d5fd6a35546b23 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 11:25:02 -0700 Subject: [PATCH 02/12] rename --- python/tvm/relay/prelude.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 6cdc5981ef9b..821ff596a269 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -24,7 +24,7 @@ from . import op from .module import Module -class TensorArrayOps: +class TensorArrayOps(object): """Contains tensor array related ops""" def __init__(self, prelude, dtype): @@ -141,14 +141,14 @@ def define_tensor_take(self): tensor6_case], False), tensor_t(), []) - def define_tensor_add_one(self): - """Defines a function to grow a tensor_t's rank by adding one dimention in front + def define_tensor_expand_dims(self): + """Defines a function to grow a tensor_t's rank by adding one dimension in front of the original tensor_t. 
- tensor_add_one(t) : tensor_t -> tensor_t + tensor_expand_dims(t) : tensor_t -> tensor_t """ - add_one_name = self.get_name("tensor_add_one") - add_one_var = GlobalVar(add_one_name) - setattr(self.prelude, add_one_name, add_one_var) + expand_dims_name = self.get_name("tensor_expand_dims") + expand_dims_var = GlobalVar(expand_dims_name) + setattr(self.prelude, expand_dims_name, expand_dims_var) tensor_type_var = self.get_var('tensor_t') x = Var("x", tensor_type_var()) t0 = Var("t0") @@ -176,7 +176,7 @@ def define_tensor_add_one(self): tensor5_var(op.expand_dims(t4, 0, 1))) tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), tensor6_var(op.expand_dims(t5, 0, 1))) - self.prelude.mod[add_one_var] =\ + self.prelude.mod[expand_dims_var] =\ Function([x], Match(x, [tensor0_case, tensor1_case, @@ -340,7 +340,7 @@ def define_tensor_array_unstack_tensor2(self): self.prelude.l(self.get_var('tensor_t')()), []) def define_tensor_array_scatter(self): - """Defines a function to scatter the values of a tensor_t in indices of a tensor array.. + """Defines a function to scatter the values of a tensor_t in indices of a tensor array. tensor_array_scatter(ta, indices, value) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] """ @@ -500,19 +500,19 @@ def define_tensor_array_stack(self): setattr(self.prelude, stack_name, stack_var) tensor_type_var = self.get_var('tensor_t') tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) - add_one_var = self.get_var('tensor_add_one') + expand_dims_var = self.get_var('tensor_expand_dims') concat_var = self.get_var('tensor_concatenate') - tensor_array_add_one = self.prelude.map(add_one_var, tensor_array) + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) tensors = self.prelude.foldl(concat_var, - self.prelude.hd(tensor_array_add_one), - self.prelude.tl(tensor_array_add_one)) + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims)) self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) def register(self): """Register all tensor array ops in Prelude""" self.define_tensor_adt() self.define_tensor_take() - self.define_tensor_add_one() + self.define_tensor_expand_dims() self.define_tensor_concat() self.define_tensor_array() self.define_tensor_array_read() From 1529286fdeb95b8cdf5240ea61921b0694d8dfaf Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 11:33:35 -0700 Subject: [PATCH 03/12] delete test --- tests/python/relay/test_adt.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 284a1e01b65a..37a3bb420fcb 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -832,24 +832,6 @@ def run(dtype): run('float32') run('int32') -def test_any_take(): - mod = relay.Module() - p = Prelude(mod) - v = relay.var('v', relay.ty.TensorType([relay.ty.Any(), relay.ty.Any()])) - lower = relay.var('lower', 'int32') - upper = relay.var('upper', 'int32') - - t1 = relay.op.take(v, relay.op.arange(lower, upper, dtype='int32'), axis=0) - - mod["main"] = relay.Function([v, lower, upper], t1) - for kind in ["debug"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - t = np.random.uniform(size=(10, 10)).astype('float32') - result = ex.evaluate()(t, 2, 5) - res = vmobj_to_list(result) - expected = [np.take(t, np.arange(2, 5), axis=0)] - tvm.testing.assert_allclose(expected, res) - if __name__ == 
"__main__": test_nat_constructor() test_double() From c73fc5b7025120e5365d916d08317e2c7d13affe Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 12:17:56 -0700 Subject: [PATCH 04/12] Move utility function --- python/tvm/relay/prelude.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 821ff596a269..f36545a91ff6 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -34,14 +34,11 @@ def __init__(self, prelude, dtype): def get_name(self, canonical): """Get name corresponding to the caninical name""" - if canonical == 'tensor_t': - return 'tensor_{}_t'.format(self.dtype) - return "{}_{}".format(canonical, self.dtype) + return self.prelude.get_name(canonical, self.dtype) def get_var(self, canonical): """Get var corresponding to the caninical name""" - name = self.get_name(canonical) - return getattr(self.prelude, name) + return self.prelude.get_var(canonical, self.dtype) def define_tensor_adt(self): """Defines the dynamic tensor ADT, which is the container for tensors @@ -535,6 +532,17 @@ def __init__(self, mod=None): self.mod = mod self.load_prelude() + def get_name(self, canonical, dtype): + """Get name corresponding to the caninical name""" + if canonical == 'tensor_t': + return 'tensor_{}_t'.format(dtype) + return "{}_{}".format(canonical, dtype) + + def get_var(self, canonical, dtype): + """Get var corresponding to the caninical name""" + name = self.get_name(canonical, dtype) + return getattr(self, name) + def load_prelude(self): """Parses the Prelude from Relay's text format into a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude From 7b458bf2687c715626d1da171c36f7dbc20b7109 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 16:12:27 -0700 Subject: [PATCH 05/12] Refactor --- python/tvm/relay/prelude.py | 4 ++-- tests/python/relay/test_adt.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index f36545a91ff6..d27ffe512617 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -533,13 +533,13 @@ def __init__(self, mod=None): self.load_prelude() def get_name(self, canonical, dtype): - """Get name corresponding to the caninical name""" + """Get name corresponding to the canonical name""" if canonical == 'tensor_t': return 'tensor_{}_t'.format(dtype) return "{}_{}".format(canonical, dtype) def get_var(self, canonical, dtype): - """Get var corresponding to the caninical name""" + """Get var corresponding to the canonical name""" name = self.get_name(canonical, dtype) return getattr(self, name) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 37a3bb420fcb..5f7a467a65a2 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -685,15 +685,14 @@ def test_iterate(): res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 -def test_tensor_array_add_one(): +def test_tensor_expand_dims(): def run(dtype): x = relay.var('x') mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) - add_one_func = tensor_array_ops.get_var('tensor_add_one') - tensor1 = tensor_array_ops.get_var('tensor1') - mod["main"] = relay.Function([x], add_one_func(tensor1(x))) + expand_dims_func = p.get_var('tensor_expand_dims', dtype) + tensor1 = p.get_var('tensor1', dtype) + mod["main"] = relay.Function([x], 
expand_dims_func(tensor1(x))) for kind in ["debug"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") x_np = np.random.uniform(size=(1,)).astype(dtype) @@ -856,7 +855,7 @@ def run(dtype): test_compose() test_iterate() - test_tensor_array_add_one() + test_tensor_expand_dims() test_tensor_array_constructor() test_tensor_array_read() test_tensor_array_stack() From 837b1e5d193012f4b926d3b2058afc48838eed2d Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 16:44:36 -0700 Subject: [PATCH 06/12] fix tensor array ops --- tests/python/relay/test_adt.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 5f7a467a65a2..48a233fb8e14 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -18,7 +18,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.prelude import Prelude, TensorArrayOps +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np @@ -708,8 +708,7 @@ def run(dtype): x = relay.var('x') mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) - tensor_array = tensor_array_ops.get_var('tensor_array') + tensor_array = p.get_var('tensor_array', dtype) mod["main"] = relay.Function([x], tensor_array(x)) for kind in ["debug"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") @@ -724,11 +723,10 @@ def test_tensor_array_read(): def run(dtype): mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) l = relay.var('l') i = relay.var('i') - read_func = tensor_array_ops.get_var('tensor_array_read') - tensor_array = tensor_array_ops.get_var('tensor_array') + read_func = p.get_var('tensor_array_read', dtype) + tensor_array = p.get_var('tensor_array', dtype) mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) for kind in ["debug"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") @@ -770,11 +768,10 @@ def test_tensor_array_stack(): def run(dtype): mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) - tensor_array = tensor_array_ops.get_var('tensor_array') - tensor1 = tensor_array_ops.get_var('tensor1') - write = tensor_array_ops.get_var('tensor_array_write') - stack = tensor_array_ops.get_var('tensor_array_stack') + tensor_array = p.get_var('tensor_array', dtype) + tensor1 = p.get_var('tensor1', dtype) + write = p.get_var('tensor_array_write', dtype) + stack = p.get_var('tensor_array_stack', dtype) l = relay.var('l') v = relay.var('v') init_tensor_array = tensor_array(relay.const(3)) @@ -797,8 +794,7 @@ def test_tensor_array_unstack(): def run(dtype): mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) - unstack_tensor1 = tensor_array_ops.get_var('tensor_array_unstack_tensor1') + unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype) v = relay.var('v') mod["main"] = relay.Function([v], unstack_tensor1(v)) for kind in ["debug"]: @@ -814,9 +810,8 @@ def test_tensor_take(): def run(dtype): mod = relay.Module() p = Prelude(mod) - tensor_array_ops = TensorArrayOps(p, dtype) - take = tensor_array_ops.get_var('tensor_take') - tensor2 = tensor_array_ops.get_var('tensor2') + take = p.get_var('tensor_take', dtype) + tensor2 = p.get_var('tensor2', dtype) v = 
relay.var('v') lower = relay.var('lower') upper = relay.var('upper') From eb938ad036155400ff3dd574ced9f0de54867f8a Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 16 Oct 2019 19:38:12 -0700 Subject: [PATCH 07/12] fix test --- tests/python/relay/test_adt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 48a233fb8e14..390d3cd9f3c4 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -738,11 +738,11 @@ def run(dtype): run('int32') def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + if isinstance(o, tvm.relay.backend.vmobj.Tensor): return [o.asnumpy().tolist()] elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.asnumpy()] - elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + elif isinstance(o, tvm.relay.backend.vmobj.Datatype): result = [] for f in o: result.extend(vmobj_to_list(f)) From 9a45b387aaefd3e6aafc76db752a3ed041dbec49 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 14:05:27 -0700 Subject: [PATCH 08/12] fix rebase --- src/relay/backend/vm/serializer.cc | 439 ----------------------------- 1 file changed, 439 deletions(-) delete mode 100644 src/relay/backend/vm/serializer.cc diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc deleted file mode 100644 index a3d882818a89..000000000000 --- a/src/relay/backend/vm/serializer.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM. - */ -#include "serializer.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -void Serializer::Init(const VirtualMachine* vm) { - vm_ = vm; - // Initialize the stream object. 
- strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Serializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); - } else if (name == "get_primitive_ops") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPrimitiveOps(); - }); - } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); - } else if (name == "get_globals") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGlobals(); - }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); - } else if (name == "serialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Serialize(); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -tvm::Array Serializer::GetPrimitiveOps() const { - std::vector ret; - for (const auto& it : vm_->primitive_map) { - auto packed_name = tvm::ir::StringImm::make(it.first); - auto packed_index = static_cast(it.second); - if (ret.size() <= packed_index) { - ret.resize(packed_index + 1); - } - ret[packed_index] = packed_name; - } - return ret; -} - -std::string Serializer::Stats() const { - std::ostringstream oss; - oss << "Relay VM statistics:" << std::endl; - - // Get the number of constants and the shape of each of them. - oss << " Constant shapes (# " << vm_->constants.size() << "): ["; - for (const auto& it : vm_->constants) { - auto* cell = it.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); - - // Scalar - if (shape.empty()) { - oss << "scalar, "; - continue; - } - - oss << "["; - for (auto s : shape) { - oss << s << ", "; - } - oss.seekp(-2, oss.cur); - oss << "], " << std::endl; - } - if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of globals and the name of each of them. - oss << " Globals (#" << vm_->global_map.size() << "): ["; - for (const auto& it : vm_->global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; - } - if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of primitive ops and the name of each of them. - oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; - const auto& prim_ops = GetPrimitiveOps(); - for (const auto& it : prim_ops) { - oss << it << ", "; - } - if (!prim_ops.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - return oss.str(); -} - -TVMByteArray Serializer::Serialize() { - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); - - // Global section. - SerializeGlobalSection(); - - // Constant section. - SerializeConstantSection(); - - // Primitive names. - SerializePrimitiveOpNames(); - - // Code section. 
- SerializeCodeSection(); - - TVMByteArray arr; - arr.data = code_.c_str(); - arr.size = code_.length(); - return arr; -} - -void Serializer::SerializeGlobalSection() { - auto globals = GetGlobals(); - std::vector glbs; - for (const auto& it : globals) { - glbs.push_back(it.as()->value); - } - strm_->Write(glbs); -} - -void Serializer::SerializeConstantSection() { - std::vector arrays; - for (const auto& obj : vm_->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); - } - strm_->Write(static_cast(vm_->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); - } -} - -void Serializer::SerializePrimitiveOpNames() { - auto names = GetPrimitiveOps(); - std::vector primitive_names; - for (const auto& it : names) { - primitive_names.push_back(it.as()->value); - } - strm_->Write(primitive_names); -} - -// Serialize a virtual machine instruction. It creates a list that contains the -// hash, opcode, and all fields of an instruction. -// -// For example, the function signature used to create an `AllocTensor` -// instruction is: -// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) -// -// The serialized form will be: -// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` -// -// where hash is the hash of serialized instruction that is computed internally -// by the `VMInstructionSerializer`. It is used for sanity check before decoding. -// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` -// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` -// is the destination register, and the rest of it together indicates the shape -// of the tensor to be allocated. -VMInstructionSerializer SerializeInstruction(const Instruction& instr) { - std::vector fields; - // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; - switch (instr.op) { - case Opcode::Move: { - // Number of fields = 2 - fields.assign({instr.from, instr.dst}); - break; - } - case Opcode::Ret: { - // Number of fields = 1 - fields.push_back(instr.result); - break; - } - case Opcode::Fatal: { - // Number of fields = 0 - break; - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - // Note that arity includes both input arguments and outputs. We will - // put all the `arity` number of fields in the end for serialization. - fields.assign({instr.packed_index, instr.arity, instr.output_size}); - // Save the args. - fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); - break; - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - - // The number of dimensions is not needed for constructing an - // `AllocTensor` instruction as it equals to the length of the `shape` - // vector. However, we save it to conveniently deserialize the instruction - // because we will know how many fields are needed by the `shape` argument. - fields.push_back(instr.alloc_tensor.ndim); - fields.push_back(instr.dst); - - // Save the shape of the tensor. - // Note that this field is rotated to the end of the list. 
- fields.insert(fields.end(), instr.alloc_tensor.shape, - instr.alloc_tensor.shape + instr.alloc_tensor.ndim); - break; - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - fields.push_back(instr.alloc_tensor_reg.shape_register); - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.insert(fields.end(), {dtype.code, dtype.bits, dtype.lanes}); - fields.push_back(instr.dst); - break; - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); - - // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); - break; - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); - - // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); - break; - } - case Opcode::If: { - // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, - instr.if_op.false_offset}); - break; - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - fields.assign({instr.func_index, instr.num_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.invoke_args_registers, - instr.invoke_args_registers + instr.num_args); - break; - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - fields.assign({instr.closure, instr.num_closure_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); - break; - } - case Opcode::LoadConst: { - // Number of fields = 2 - fields.assign({instr.const_index, instr.dst}); - break; - } - case Opcode::LoadConsti: { - // Number of fields = 2 - fields.assign({instr.load_consti.val, instr.dst}); - break; - } - case Opcode::GetField: { - // Number of fields = 3 - fields.assign({instr.object, instr.field_index, instr.dst}); - break; - } - case Opcode::GetTag: { - // Number of fields = 2 - fields.assign({instr.get_tag.object, instr.dst}); - break; - } - case Opcode::Goto: { - // Number of fields = 1 - fields.push_back(instr.pc_offset); - break; - } - default: - LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); - break; - } - - return VMInstructionSerializer(static_cast(instr.op), fields); -} - -void Serializer::SerializeCodeSection() { - // Save the number of functions. - strm_->Write(static_cast(vm_->functions.size())); - for (const auto& func : vm_->functions) { - // Serialize the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), - func.params); - func_format.Save(strm_); - - // Serialize each instruction. 
- for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); - } - } -} - -tvm::Array Serializer::GetGlobals() const { - tvm::Array ret; - std::vector > globals(vm_->global_map.begin(), - vm_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - for (const auto& it : globals) { - ret.push_back(tvm::ir::StringImm::make(it.first)); - } - return ret; -} - -std::string Serializer::GetBytecode() const { - std::ostringstream oss; - - for (const auto& func : vm_->functions) { - // Print the header of the function format. - oss << "# func name, reg file size, param count, inst count:" - << std::endl; - oss << func.name << " " - << func.register_file_size << " " - << func.params.size() << " " - << func.instructions.size() << std::endl; - - // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; - for (const auto& param : func.params) { - oss << param << " "; - } - oss << std::endl; - - // Print the instructions of a `VMFunction`. - // The part after ";" is the instruction in text format. - oss << "hash, opcode, fields # inst(text):"<< std::endl; - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - oss << std::hex << "0x" << serialized_instr.Hash() << " " - << std::dec << serialized_instr.opcode << " "; - for (auto it : serialized_instr.fields) { - oss << it << " "; - } - oss << " # " << instr; - if (oss.str().back() != '\n') oss << std::endl; - } - } - - return oss.str(); -} - -runtime::Module Serializer::GetLib() const { - return vm_->lib; -} - -runtime::Module CreateSerializer(const VirtualMachine* vm) { - std::shared_ptr exec = std::make_shared(); - exec->Init(vm); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Serializer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - runtime::Module mod = args[0]; - const auto* vm = dynamic_cast(mod.operator->()); - CHECK(vm) << "Virtual machine has not been defined yet." - << "\n"; - *rv = CreateSerializer(vm); -}); - -} // namespace vm -} // namespace relay -} // namespace tvm From 3ed9437360ca0a284fbb2f6532fdb840ce814f16 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 15:21:15 -0700 Subject: [PATCH 09/12] Fix serializer bug --- src/runtime/vm/executable.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 21f71af4eb8c..f85283094e91 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. 
const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(dtype.code); + fields.push_back(dtype.bits); + fields.push_back(dtype.lanes); fields.push_back(instr.dst); break; } From 1bc2ed01db3d284e90847eb6c0b6d7ee10e623c8 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 15:31:32 -0700 Subject: [PATCH 10/12] Improve tf convert name lookup to use prelude api --- python/tvm/relay/frontend/tensorflow.py | 34 +++++++++---------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 9a7b471e9095..4dedcfb55064 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -515,7 +515,7 @@ def _impl(inputs, attr, params): def _tensor_array(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('dtype').name - tensor_array_constructor = getattr(prelude, "tensor_array_{}".format(dtype_str)) + tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) return _impl @@ -523,13 +523,10 @@ def _tensor_array_scatter(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name values_rank = len(inputs[2].type_annotation.shape) - unstack_function_name = "tensor_array_unstack_tensor{}_{}".format(values_rank, dtype_str) - - values = getattr(prelude, unstack_function_name)(inputs[2]) - - tensor_array_scatter_name = "tensor_array_scatter_{}".format(dtype_str) - tensor_array_scatter_var = getattr(prelude, tensor_array_scatter_name) - return tensor_array_scatter_var(inputs[0], inputs[1], values) + unstack_function = prelude.get_var("tensor_array_unstack_tensor{}".format(values_rank), dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) + return tensor_array_scatter_func(inputs[0], inputs[1], values) return _impl def _tensor_array_gather(): @@ -547,20 +544,17 @@ def _impl(inputs, attr, params, prelude): input_rank = len(inputs[2].type_annotation.shape) dtype = attr.get('T').name - tensor_name = 'tensor{}_{}'.format(input_rank, dtype) - tensor_func = getattr(prelude, tensor_name) + tensor_name = 'tensor{}'.format(input_rank) + tensor_func = prelude.get_var(tensor_name, dtype) v = tensor_func(inputs[2]) - - write_name = 'tensor_array_write_{}'.format(dtype) - write_func = getattr(prelude, write_name) + write_func = prelude.get_var('tensor_array_write', dtype) return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) return _impl def _tensor_array_read(): def _impl(inputs, attr, params, prelude): - read_name = 'tensor_array_read_{}'.format(attr.get('dtype').name) - read_func = getattr(prelude, read_name) + read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) return _impl @@ -568,19 +562,15 @@ def _tensor_array_split(): def _impl(inputs, attr, params, prelude): input_rank = len(inputs[1].type_annotation.shape) dtype_str = attr.get('T').name - tensor_constructor_name = "tensor{}_{}".format(input_rank, dtype_str) - v = getattr(prelude, tensor_constructor_name)(inputs[1]) + v = prelude.get_var("tensor{}".format(input_rank), dtype_str) lengths = _op.cast(inputs[2], 'int32') - - split_name = "tensor_array_split_{}".format(dtype_str) - split_var = getattr(prelude, split_name) + split_var = 
prelude.get_var('tensor_array_split', dtype_str) return split_var(inputs[0], v, lengths) return _impl def _tensor_array_concat(): def _impl(inputs, attr, params, prelude): - concat_name = 'tensor_array_concat_{}'.format(attr['dtype'].name) - concat_func = getattr(prelude, concat_name) + concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) return concat_func(inputs[1]) return _impl From c421b4ab347e1bae48d5c78d158d2016744436f3 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 15:34:25 -0700 Subject: [PATCH 11/12] Fix lint --- python/tvm/relay/frontend/tensorflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4dedcfb55064..749894454b3d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -523,7 +523,8 @@ def _tensor_array_scatter(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name values_rank = len(inputs[2].type_annotation.shape) - unstack_function = prelude.get_var("tensor_array_unstack_tensor{}".format(values_rank), dtype_str) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_var(unstack_name, dtype_str) values = unstack_function(inputs[2]) tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) return tensor_array_scatter_func(inputs[0], inputs[1], values) From 4b4c51b58ffe073a6b6ca441cc1ad2dca38887b2 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 20:05:36 -0700 Subject: [PATCH 12/12] Fix test --- python/tvm/relay/frontend/tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 749894454b3d..eb67cf24b81e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -563,7 +563,7 @@ def _tensor_array_split(): def _impl(inputs, attr, params, prelude): input_rank = len(inputs[1].type_annotation.shape) dtype_str = attr.get('T').name - v = prelude.get_var("tensor{}".format(input_rank), dtype_str) + v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) lengths = _op.cast(inputs[2], 'int32') split_var = prelude.get_var('tensor_array_split', dtype_str) return split_var(inputs[0], v, lengths)
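
Usage sketch (not part of the patch series): the final hunk matters because prelude.get_var("tensor{}".format(input_rank), dtype_str) only returns the dtype-specialized constructor; the converter still has to apply it to inputs[1] to obtain a tensor_t value. The same lookup-then-apply pattern can be exercised directly against the Prelude, mirroring the tests in test_adt.py above; the executor choice and the concrete shapes/values below are illustrative assumptions rather than part of the patches.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.prelude import Prelude

    mod = relay.Module()
    p = Prelude(mod)
    dtype = 'float32'

    # get_var returns dtype-specialized GlobalVars; each one still has to be
    # applied to its arguments (the distinction the last patch fixes in the
    # TensorArraySplitV3 converter).
    tensor1 = p.get_var('tensor1', dtype)            # wraps a rank-1 tensor into tensor_t
    tensor_array = p.get_var('tensor_array', dtype)  # allocates a list[tensor_t] of a given size
    write = p.get_var('tensor_array_write', dtype)
    read = p.get_var('tensor_array_read', dtype)

    v = relay.var('v')
    ta = write(tensor_array(relay.const(1)), relay.const(0), tensor1(v))
    mod["main"] = relay.Function([v], read(ta, relay.const(0)))

    ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
    result = ex.evaluate()(np.array([1.0, 2.0, 3.0], dtype=dtype))
    # result is a tensor_t ADT value (the tensor1 constructor) wrapping the
    # input; a helper like vmobj_to_list in test_adt.py is one way to unpack it.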