[Relay][Frontend][TF] Add tensor array ops #3798

Merged 12 commits on Oct 18, 2019
82 changes: 81 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -22,10 +22,14 @@

import warnings
from collections import defaultdict

# Numpy support
import numpy as np

import tvm

from tvm.relay.prelude import Prelude

from .. import analysis
from .. import expr as _expr
from .. import op as _op
@@ -508,6 +512,69 @@ def _impl(inputs, attr, params):
        return _op.concatenate(inputs_reshaped, axis)
    return _impl

def _tensor_array():
    def _impl(inputs, attr, params, prelude):
        dtype_str = attr.get('dtype').name
        tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
        return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
    return _impl
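A minimal sketch, assuming the TF 1.x API, of the graph op this converter handles: TensorArrayV3 takes a scalar size input, which the frontend receives as a tensor, hence the _op.take(inputs[0], tvm.relay.const(0)) above to pull out element 0.

import tensorflow as tf  # TF 1.x assumed

# Emits a TensorArrayV3 node whose first input is the size.
ta = tf.TensorArray(dtype=tf.float32, size=3)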

def _tensor_array_scatter():
    def _impl(inputs, attr, params, prelude):
        dtype_str = attr.get('T').name
        values_rank = len(inputs[2].type_annotation.shape)
        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
        unstack_function = prelude.get_var(unstack_name, dtype_str)
        values = unstack_function(inputs[2])
        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
        return tensor_array_scatter_func(inputs[0], inputs[1], values)
    return _impl

def _tensor_array_gather():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_gather(inputs[2], inputs[1])
    return _impl

def _tensor_array_size():
    def _impl(inputs, attr, params, prelude):
        return prelude.length(inputs[0])
    return _impl

def _tensor_array_write():
    def _impl(inputs, attr, params, prelude):
        input_rank = len(inputs[2].type_annotation.shape)
        dtype = attr.get('T').name

        tensor_name = 'tensor{}'.format(input_rank)
        tensor_func = prelude.get_var(tensor_name, dtype)
        v = tensor_func(inputs[2])
        write_func = prelude.get_var('tensor_array_write', dtype)

        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
    return _impl

def _tensor_array_read():
    def _impl(inputs, attr, params, prelude):
        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
    return _impl

def _tensor_array_split():
    def _impl(inputs, attr, params, prelude):
        input_rank = len(inputs[1].type_annotation.shape)
        dtype_str = attr.get('T').name
        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)
        lengths = _op.cast(inputs[2], 'int32')
        split_var = prelude.get_var('tensor_array_split', dtype_str)
        return split_var(inputs[0], v, lengths)
    return _impl

def _tensor_array_concat():
    def _impl(inputs, attr, params, prelude):
        concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
        return concat_func(inputs[1])
    return _impl
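Taken together, these converters let a TF 1.x graph that manipulates tensor arrays be imported end to end. A hedged usage sketch (TF 1.x API assumed; op and tensor names are whatever the graph assigns):

import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    ta = tf.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    out = ta.read(0)
    graph_def = graph.as_graph_def()

# The graph contains TensorArrayV3/TensorArrayWriteV3/TensorArrayReadV3
# nodes, which dispatch to the converters above.
mod, params = relay.frontend.from_tensorflow(graph_def, outputs=[out.op.name])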

def _tile():
    def _impl(inputs, attr, params):
        reps = _get_list_param(params, inputs.pop())
@@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params):
    'NotEqual' : _broadcast('not_equal'),
    'OneHot' : _one_hot(),
    'Pack' : _pack(),
    'TensorArrayV3' : _tensor_array(),
    'TensorArrayScatterV3' : _tensor_array_scatter(),
    'TensorArrayGatherV3' : _tensor_array_gather(),
    'TensorArraySizeV3' : _tensor_array_size(),
    'TensorArrayWriteV3' : _tensor_array_write(),
    'TensorArrayReadV3' : _tensor_array_read(),
    'TensorArraySplitV3' : _tensor_array_split(),
    'TensorArrayConcatV3' : _tensor_array_concat(),
    'Pad' : _pad('Pad'),
    'PadV2' : _pad('PadV2'),
    'Pow' : _elemwise('power'),
@@ -1860,6 +1935,7 @@ def __init__(self):
        self._loops = {}
        self._branches = {}
        self._mod = _module.Module({})
        self._prelude = Prelude(self._mod)
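The Prelude registers the tensor-array ADT and its dtype-specialized helpers as global definitions in the module, which is what the converters above look up by name. A minimal sketch of that lookup, assuming this era's API:

from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
prelude = Prelude(mod)
# Returns the float32 specialization, e.g. the tensor_array constructor.
ta_ctor = prelude.get_var('tensor_array', 'float32')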

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs,
        if op_name in identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            if 'TensorArray' in op_name:
                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
            else:
                sym = convert_map[op_name](inputs, attrs, self._params)

        elif op_name in convert_map_rnn:
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
26 changes: 26 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):

register_schedule("clip", schedule_elemwise)

@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = x[i]
    return out

def cast_shape_func(attrs, inputs, out_ndims):
    return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
    ndim = len(x.shape)
    out = output_tensor((ndim+1,), "int64")
    out[0] = int64(1)
    for i in const_range(0, ndim):
        out[i+1] = int64(x.shape[i])
    return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
    return [_expand_dims_shape_func(*inputs)]
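What the two new shape functions compute, restated in plain NumPy for illustration only (not TVM API): cast leaves the shape unchanged, and this expand_dims shape function handles the prepend-a-leading-1 case (axis 0), which is the case the tensor array ops need.

import numpy as np

def cast_shape(shape):
    # mirrors _cast_shape_function: output shape equals input shape
    return np.asarray(shape, dtype='int64')

def expand_dims_shape(shape):
    # mirrors _expand_dims_shape_func: prepend a leading 1
    return np.concatenate([[1], np.asarray(shape, dtype='int64')])

print(cast_shape([3, 4]))         # [3 4]
print(expand_dims_shape([3, 4]))  # [1 3 4]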

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
@@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
    return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)