[Relay][Frontend][TF] Add tensor array ops
wweic committed Sep 15, 2019
1 parent 4b431c6 commit c3f80b8
Showing 8 changed files with 771 additions and 30 deletions.
73 changes: 72 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -25,6 +25,9 @@
import numpy as np

import tvm

from tvm.relay.prelude import Prelude

from .. import analysis
from .. import expr as _expr
from .. import op as _op
@@ -505,6 +508,61 @@ def _impl(inputs, attr, params):
        return _op.concatenate(inputs_reshaped, axis)
    return _impl

def _tensor_array():
    def _impl(inputs, attr, params, prelude):
        # TensorArrayV3's `size` input arrives as a tensor; take(..., 0)
        # extracts the scalar element expected by the Prelude constructor.
        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0)))
    return _impl

def _tensor_array_scatter():
    def _impl(inputs, attr, params, prelude):
        values = None
        if len(inputs[2].type_annotation.shape) == 1:
            # 1-D values are left unhandled here; `values` stays None.
            pass
        elif len(inputs[2].type_annotation.shape) == 2:
            values = prelude.tensor_array_unstack_tensor2(inputs[2])

        return prelude.tensor_array_scatter(inputs[0], inputs[1], values)
    return _impl
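
For orientation: tensor_array_unstack_tensor2 conceptually splits a 2-D values tensor along axis 0 into per-row tensors, and tensor_array_scatter writes them at the given indices. A minimal NumPy sketch of that behaviour (illustration only, not part of the diff):

import numpy as np

values = np.arange(6, dtype="float32").reshape(3, 2)
indices = np.array([2, 0, 1])
# unstack along axis 0: one 1-D tensor per row of `values`
unstacked = [values[i] for i in range(values.shape[0])]
# scatter: element i of `unstacked` lands at position indices[i]
target = [None] * 3
for i, idx in enumerate(indices):
    target[idx] = unstacked[i]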

def _tensor_array_gather():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_gather(inputs[2], inputs[1])
    return _impl

def _tensor_array_size():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_size(inputs[0])
    return _impl

def _tensor_array_write():
    def _impl(inputs, attr, params, prelude):
        # Wrap the concrete 2-D/3-D value in the Prelude's tensor ADT before writing.
        if len(inputs[2].type_annotation.shape) == 2:
            v = prelude.tensor2(inputs[2])
        elif len(inputs[2].type_annotation.shape) == 3:
            v = prelude.tensor3(inputs[2])
        return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
    return _impl

def _tensor_array_read():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
    return _impl

def _tensor_array_split():
    def _impl(inputs, attr, params, prelude):
        if len(inputs[1].type_annotation.shape) == 2:
            v = prelude.tensor2(inputs[1])
        elif len(inputs[1].type_annotation.shape) == 3:
            v = prelude.tensor3(inputs[1])
        lengths = _op.cast(inputs[2], 'int32')
        return prelude.tensor_array_split(inputs[0], v, lengths)
    return _impl

def _tensor_array_concat():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_concat(inputs[1])
    return _impl
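
The converters above only assemble calls into the Prelude's tensor-array definitions. A minimal sketch of driving those definitions directly from Relay, assuming the relay.Module/Prelude API of this TVM version (illustration only, not part of the diff):

from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)

# build an empty tensor array of length 2, wrap a 2-D value in the tensor ADT,
# and write it at index 0 -- the same calls the converters above emit
x = relay.var("x", shape=(3, 4), dtype="float32")
ta = p.tensor_array(relay.const(2))
ta = p.tensor_array_write(ta, relay.const(0), p.tensor2(x))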

def _tile():
    def _impl(inputs, attr, params):
        reps = _get_list_param(params, inputs.pop())
@@ -1302,6 +1360,14 @@ def _impl(inputs, attr, params):
    'NotEqual' : _broadcast('not_equal'),
    'OneHot' : _one_hot(),
    'Pack' : _pack(),
    'TensorArrayV3' : _tensor_array(),
    'TensorArrayScatterV3' : _tensor_array_scatter(),
    'TensorArrayGatherV3' : _tensor_array_gather(),
    'TensorArraySizeV3' : _tensor_array_size(),
    'TensorArrayWriteV3' : _tensor_array_write(),
    'TensorArrayReadV3' : _tensor_array_read(),
    'TensorArraySplitV3' : _tensor_array_split(),
    'TensorArrayConcatV3' : _tensor_array_concat(),
    'Pad' : _pad('Pad'),
    'PadV2' : _pad('PadV2'),
    'Pow' : _elemwise('power'),
@@ -1847,6 +1913,7 @@ def __init__(self):
        self._loops = {}
        self._branches = {}
        self._mod = _module.Module({})
        self._prelude = Prelude(self._mod)

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2322,7 +2389,11 @@ def _convert_operator(self, op_name, inputs, attrs,
        if op_name in identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            # TensorArray converters additionally receive the Prelude.
            if 'TensorArray' in op_name:
                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
            else:
                sym = convert_map[op_name](inputs, attrs, self._params)

        elif op_name in convert_map_rnn:
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
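
With the converters registered and the Prelude attached to GraphProto, a TF 1.x graph that uses tf.TensorArray can be imported end to end. A hedged sketch (not part of this commit; the graph construction and output naming are illustrative):

import numpy as np
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    t = tf.constant(np.random.rand(2, 3).astype("float32"))
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False)
    ta = ta.write(0, t)
    out = ta.read(0)
    graph_def = graph.as_graph_def()

# TensorArray*V3 nodes are dispatched to the Prelude-aware converters above.
mod, params = relay.frontend.from_tensorflow(graph_def, outputs=[out.op.name])
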
26 changes: 26 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -107,6 +107,29 @@ def clip_compute(attrs, inputs, output_type, target):

register_schedule("clip", schedule_elemwise)

@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = x[i]
    return out

def cast_shape_func(attrs, inputs, out_ndims):
    return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
    ndim = len(x.shape)
    out = output_tensor((ndim+1,), "int64")
    out[0] = int64(1)
    for i in const_range(0, ndim):
        out[i+1] = int64(x.shape[i])
    return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
    return [_expand_dims_shape_func(*inputs)]

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
@@ -139,6 +162,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
    return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)