Skip to content

Commit

Permalink
[Relay][Frontend][TF] Add tensor array ops (#3798)
Browse files Browse the repository at this point in the history
* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve tf convert name lookup to use prelude api

* Fix lint

* Fix test
  • Loading branch information
wweic authored and zhiics committed Oct 18, 2019
1 parent 4052de6 commit 36a9677
Show file tree
Hide file tree
Showing 8 changed files with 899 additions and 10 deletions.
82 changes: 81 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@

import warnings
from collections import defaultdict

# Numpy support
import numpy as np

import tvm

from tvm.relay.prelude import Prelude

from .. import analysis
from .. import expr as _expr
from .. import op as _op
Expand Down Expand Up @@ -508,6 +512,69 @@ def _impl(inputs, attr, params):
return _op.concatenate(inputs_reshaped, axis)
return _impl

def _tensor_array():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('dtype').name
tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
return _impl

def _tensor_array_scatter():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('T').name
values_rank = len(inputs[2].type_annotation.shape)
unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
unstack_function = prelude.get_var(unstack_name, dtype_str)
values = unstack_function(inputs[2])
tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
return tensor_array_scatter_func(inputs[0], inputs[1], values)
return _impl

def _tensor_array_gather():
def _impl(inputs, attr, params, prelude):
return prelude.tensor_array_gather(inputs[2], inputs[1])
return _impl

def _tensor_array_size():
def _impl(inputs, attr, params, prelude):
return prelude.length(inputs[0])
return _impl

def _tensor_array_write():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[2].type_annotation.shape)
dtype = attr.get('T').name

tensor_name = 'tensor{}'.format(input_rank)
tensor_func = prelude.get_var(tensor_name, dtype)
v = tensor_func(inputs[2])
write_func = prelude.get_var('tensor_array_write', dtype)

return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
return _impl

def _tensor_array_read():
def _impl(inputs, attr, params, prelude):
read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
return _impl

def _tensor_array_split():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[1].type_annotation.shape)
dtype_str = attr.get('T').name
v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
lengths = _op.cast(inputs[2], 'int32')
split_var = prelude.get_var('tensor_array_split', dtype_str)
return split_var(inputs[0], v, lengths)
return _impl

def _tensor_array_concat():
def _impl(inputs, attr, params, prelude):
concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
return concat_func(inputs[1])
return _impl

def _tile():
def _impl(inputs, attr, params):
reps = _get_list_param(params, inputs.pop())
Expand Down Expand Up @@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params):
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'TensorArrayV3' : _tensor_array(),
'TensorArrayScatterV3' : _tensor_array_scatter(),
'TensorArrayGatherV3' : _tensor_array_gather(),
'TensorArraySizeV3' : _tensor_array_size(),
'TensorArrayWriteV3' : _tensor_array_write(),
'TensorArrayReadV3' : _tensor_array_read(),
'TensorArraySplitV3' : _tensor_array_split(),
'TensorArrayConcatV3' : _tensor_array_concat(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
Expand Down Expand Up @@ -1860,6 +1935,7 @@ def __init__(self):
self._loops = {}
self._branches = {}
self._mod = _module.Module({})
self._prelude = Prelude(self._mod)

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Expand Down Expand Up @@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs,
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
if 'TensorArray' in op_name:
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = convert_map[op_name](inputs, attrs, self._params)

elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):

register_schedule("clip", schedule_elemwise)

@script
def _cast_shape_function(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out

def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
ndim = len(x.shape)
out = output_tensor((ndim+1,), "int64")
out[0] = int64(1)
for i in const_range(0, ndim):
out[i+1] = int64(x.shape[i])
return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
return [_expand_dims_shape_func(*inputs)]

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
Expand Down Expand Up @@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
Expand Down
Loading

0 comments on commit 36a9677

Please sign in to comment.