Commit: [Relay][Frontend][TF] Add tensor array ops

wweic committed Aug 22, 2019
1 parent d4b66da commit 42f78bb

Showing 8 changed files with 813 additions and 28 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/adt.py
@@ -186,7 +186,7 @@ def __init__(self, lhs, rhs):
class Match(Expr):
    """Pattern matching expression in Relay."""

-    def __init__(self, data, clauses, complete=True):
+    def __init__(self, data, clauses, complete=False):
        """Construct a Match.
        Parameters
85 changes: 84 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -24,7 +24,13 @@
# Numpy support
import numpy as np

import tvm

from tvm.relay.prelude import Prelude
from topi.util import get_const_tuple

from .. import analysis
from .. import expr as _expr
from .. import op as _op
@@ -506,6 +512,69 @@ def _impl(inputs, attr, params):
return _op.concatenate(inputs_reshaped, axis)
return _impl

def _tensor_array():
    def _impl(inputs, attr, params, prelude):
        # inputs[0] carries the array size; take element 0 to get a scalar.
        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0)))
    return _impl

def _tensor_array_scatter():
    def _impl(inputs, attr, params, prelude):
        values = None
        if len(inputs[2].type_annotation.shape) == 1:
            # 1-D values are not handled yet.
            pass
        elif len(inputs[2].type_annotation.shape) == 2:
            # Unstack the 2-D value tensor into prelude tensors before scattering.
            values = prelude.tensor_array_unstack_tensor2(inputs[2])

        return prelude.tensor_array_scatter(inputs[0], inputs[1], values)
    return _impl

def _tensor_array_gather():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_gather(inputs[2], inputs[1])
    return _impl

def _tensor_array_size():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_size(inputs[0])
    return _impl

def _tensor_array_write():
    def _impl(inputs, attr, params, prelude):
        # Wrap the value (inputs[2]) in the prelude tensor ADT matching its rank.
        if len(inputs[2].type_annotation.shape) == 2:
            v = prelude.tensor2(inputs[2])
        elif len(inputs[2].type_annotation.shape) == 3:
            v = prelude.tensor3(inputs[2])
        return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
    return _impl

def _tensor_array_read():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
    return _impl

def _tensor_array_split():
    def _impl(inputs, attr, params, prelude):
        # Wrap the value (inputs[1]) in the prelude tensor ADT matching its rank.
        if len(inputs[1].type_annotation.shape) == 2:
            v = prelude.tensor2(inputs[1])
        elif len(inputs[1].type_annotation.shape) == 3:
            v = prelude.tensor3(inputs[1])
        lengths = _op.cast(inputs[2], 'int32')
        return prelude.tensor_array_split(inputs[0], v, lengths)
    return _impl

def _tensor_array_concat():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_concat(inputs[1])
    return _impl

def _tile():
def _impl(inputs, attr, params):
reps = params[inputs.pop().name_hint].asnumpy()
@@ -968,6 +1037,7 @@ def _impl(inputs, attr, params):

def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
@@ -1285,6 +1355,14 @@ def _impl(inputs, attr, params):
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
'Pack' : _pack(),
'TensorArrayV3' : _tensor_array(),
'TensorArrayScatterV3' : _tensor_array_scatter(),
'TensorArrayGatherV3' : _tensor_array_gather(),
'TensorArraySizeV3' : _tensor_array_size(),
'TensorArrayWriteV3' : _tensor_array_write(),
'TensorArrayReadV3' : _tensor_array_read(),
'TensorArraySplitV3' : _tensor_array_split(),
'TensorArrayConcatV3' : _tensor_array_concat(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
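For orientation, a minimal TF 1.x sketch (assuming tensorflow 1.x is available; the shapes and values are arbitrary) of the graph nodes these new entries cover — in graph mode each call below emits the corresponding TensorArray*V3 op:

import tensorflow as tf

# Each TensorArray method call creates one of the V3 graph nodes listed above.
ta = tf.TensorArray(dtype=tf.float32, size=4, clear_after_read=False)  # TensorArrayV3
ta = ta.write(0, tf.constant([3.0, 4.0]))       # TensorArrayWriteV3
ta = ta.scatter([1, 2, 3], tf.ones((3, 2)))     # TensorArrayScatterV3
elem = ta.read(0)                               # TensorArrayReadV3
gathered = ta.gather([1, 2])                    # TensorArrayGatherV3
size = ta.size()                                # TensorArraySizeV3
concat = ta.concat()                            # TensorArrayConcatV3

ta2 = tf.TensorArray(dtype=tf.float32, size=2)
ta2 = ta2.split(tf.ones((4, 2)), [2, 2])        # TensorArraySplitV3

with tf.Session() as sess:
    print(sess.run([elem, gathered, size, concat]))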
@@ -1830,6 +1908,7 @@ def __init__(self):
self._loops = {}
self._branches = {}
self._mod = _module.Module({})
self._prelude = Prelude(self._mod)

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2306,7 +2385,11 @@ def _convert_operator(self, op_name, inputs, attrs,
        if op_name in identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
-           sym = convert_map[op_name](inputs, attrs, self._params)
+           if 'TensorArray' in op_name:
+               sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+           else:
+               sym = convert_map[op_name](inputs, attrs, self._params)

elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
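A hedged end-to-end sketch of driving the importer over a small graph that uses a TensorArray (assuming tensorflow 1.x; the tensor names and shapes are illustrative placeholders, not part of this change):

import tensorflow as tf
import tvm.relay as relay

with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, shape=(3, 2), name="input")
    ta = tf.TensorArray(dtype=tf.float32, size=3)
    ta = ta.scatter([0, 1, 2], x)                     # TensorArrayScatterV3
    tf.identity(ta.gather([0, 1, 2]), name="output")  # TensorArrayGatherV3

# TensorArray* nodes are routed through the Prelude-backed handlers above,
# so the returned module also carries the prelude's ADT definitions.
mod, params = relay.frontend.from_tensorflow(
    g.as_graph_def(), shape={"input": (3, 2)}, outputs=["output"])
print(mod["main"])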
74 changes: 74 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -20,6 +20,8 @@
import topi
from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern
from ...hybrid import script
from ...api import convert

schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
@@ -104,3 +106,75 @@ def clip_compute(attrs, inputs, output_type, target):
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]

register_schedule("clip", schedule_elemwise)

@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = x[i]
    return out

def cast_shape_func(attrs, inputs, out_ndims):
    return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
    ndim = len(x.shape)
    out = output_tensor((ndim+1,), "int64")
    out[0] = int64(1)
    for i in const_range(0, ndim):
        out[i+1] = int64(x.shape[i])
    return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
    return [_expand_dims_shape_func(*inputs)]
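As a plain-Python illustration (not hybrid script) of what these two functions are intended to compute at runtime — cast leaves the shape untouched, expand_dims prepends a unit axis:

# Illustrative only: mirrors the intent of _cast_shape_function and
# _expand_dims_shape_func on ordinary Python lists.
def cast_out_shape(in_shape):
    return list(in_shape)                  # same shape, only the dtype changes

def expand_dims_out_shape(in_shape):
    return [1] + list(in_shape)            # new leading axis of length 1

assert cast_out_shape([5, 7]) == [5, 7]
assert expand_dims_out_shape([5, 7]) == [1, 5, 7]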

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
    out = output_tensor((ndim,), "int64")
    if len(x.shape) == 0:
        for i in const_range(ndim):
            out[i] = y[i]
    elif len(y.shape) == 0:
        for i in const_range(ndim):
            out[i] = x[i]
    else:
        ndim1 = x.shape[0]
        ndim2 = y.shape[0]
        for i in const_range(1, min(ndim1, ndim2)+1):
            if x[ndim1-i] == y[ndim2-i]:
                out[ndim-i] = x[ndim1-i]
            elif x[ndim1-i] == 1:
                out[ndim-i] = y[ndim2-i]
            else:
                assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
                    x[ndim1-i], y[ndim2-i])
                out[ndim-i] = x[ndim1-i]
        for i in const_range(min(ndim1, ndim2)+1, ndim+1):
            if ndim1 >= ndim2:
                out[ndim-i] = x[ndim1-i]
            else:
                out[ndim-i] = y[ndim2-i]
    return out

def broadcast_shape_func(attrs, inputs, out_ndims):
    return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
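To make the rule these registrations rely on concrete, here is a plain-Python mirror of _broadcast_shape_func (illustrative only, assuming both shapes are given as explicit lists; the hybrid version additionally handles rank-0 shape inputs, which this mirror skips):

# Illustrative mirror of the right-aligned broadcast rule used above.
def broadcast_shape(x, y):
    ndim = max(len(x), len(y))
    out = [0] * ndim
    for i in range(1, ndim + 1):
        a = x[-i] if i <= len(x) else 1    # missing leading dims count as 1
        b = y[-i] if i <= len(y) else 1
        if a == b or b == 1:
            out[-i] = a
        elif a == 1:
            out[-i] = b
        else:
            raise ValueError("Incompatible broadcast dims %s and %s" % (a, b))
    return out

assert broadcast_shape([2, 3, 1], [3, 4]) == [2, 3, 4]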