Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW]Add Split and realdiv op support
Browse files Browse the repository at this point in the history
* Add Split and realdiv op support

* Fix the pad calculation in the case of dilated convolution
  • Loading branch information
Rasterer committed Dec 6, 2018
1 parent 990521d commit 45f88e2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 15 deletions.
50 changes: 35 additions & 15 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _impl(inputs, attr, params):
attr['channels'] = input_shape[3] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = weights_shape
Expand Down Expand Up @@ -252,8 +252,12 @@ def _impl(inputs, attr, params):
in_h = input_shape[2]
in_w = input_shape[3]

pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1]
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
Expand Down Expand Up @@ -783,6 +787,15 @@ def _impl(inputs, attr, params):
)(inputs, attr)
return _impl

def _split():
def _impl(inputs, attr, params):
axis = params.pop(inputs[0].list_output_names()[0])
return AttrCvt(
op_name="split", ignores=['T'],
transforms={'num_split': 'indices_or_sections'},
extras={'axis': axis.asnumpy()[0]})(inputs[1], attr)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -813,6 +826,7 @@ def _impl(inputs, attr, params):
'Add' : _elemwise('add'),
'Sub' : _elemwise('sub'),
'Mul' : _elemwise('mul'),
'RealDiv' : _elemwise('div'),
'Maximum' : _elemwise('max'),
'Minimum' : _elemwise('min'),
'Sum' : _sum(),
Expand Down Expand Up @@ -849,6 +863,7 @@ def _impl(inputs, attr, params):
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
'Split' : _split(),
}

# _convert_map_rnn defines maps of rnn operator name to
Expand Down Expand Up @@ -1144,21 +1159,26 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Pass the target layout
attr["_target_layout"] = layout

#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name

# Fill shapes for all inputs in a list
inputs = []
for i in node.input:
if i in self._nodes:
inputs.append(self._nodes[i])
input_shapes[self._nodes[i]] = self._output_shapes[i]
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
tensor_name = i.split(':')
node_name = tensor_name[0]
if node_name in self._nodes:
in_sym = self._nodes[node_name]
if len(in_sym.list_output_names()) > 1:
tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
in_sym = in_sym[tensor_slot]
input_shape = (self._output_shapes[node_name])[tensor_slot]
else:
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym)
input_shapes[in_sym] = [input_shape]
attr['_input_shapes'] = input_shapes

inputs = self._fix_extranodes(node.op, attr, inputs)
Expand Down
79 changes: 79 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,83 @@ def test_forward_gather():
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')


#######################################################################
# Split
# -----

def _test_split(in_shape, axis, num_split, dtype):
""" One iteration of a Split """

with tf.Graph().as_default():
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.split(in_data, num_split, axis)
np_data = np.random.uniform(size=in_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'split:0')

def test_forward_split():
'''test split layer'''
# rank 1
_test_split((3,), 0, 1, 'float32')
_test_split((3,), 0, 3, 'float32')
_test_split((6,), 0, 3, 'float32')
# rank 2
_test_split((6, 2), 0, 3, 'float32')
_test_split((2, 6), 1, 3, 'float32')
# rank 3
_test_split((6, 2, 4), 0, 3, 'float32')
_test_split((2, 6, 4), 1, 3, 'float32')
_test_split((2, 4, 6), 2, 3, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')


#######################################################################
# Split followed by concat
# ------------------------

def _test_split_concat(in_shape, axis, num_split, dtype):
""" One iteration of a split_concat pair"""

with tf.Graph().as_default():
in_data = tf.placeholder(dtype, in_shape, name="in_data")
splitted = tf.split(in_data, num_split, axis)
tf.concat(splitted, axis)
np_data = np.random.uniform(size=in_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')

def test_forward_split_concat():
'''test split followed by concat layers'''
# rank 1
_test_split_concat((3,), 0, 1, 'float32')
_test_split_concat((3,), 0, 3, 'float32')
_test_split_concat((6,), 0, 3, 'float32')
# rank 2
_test_split_concat((6, 2), 0, 3, 'float32')
_test_split_concat((2, 6), 1, 3, 'float32')
# rank 3
_test_split_concat((6, 2, 4), 0, 3, 'float32')
_test_split_concat((2, 6, 4), 1, 3, 'float32')
_test_split_concat((2, 4, 6), 2, 3, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')


#######################################################################
# Multi Input to graph
# --------------------
Expand Down Expand Up @@ -1061,6 +1138,8 @@ def test_forward_rel_ops():
test_forward_pad()
test_forward_gather()
test_forward_stridedslice()
test_forward_split()
test_forward_split_concat()

# Activations
test_forward_sigmoid()
Expand Down

0 comments on commit 45f88e2

Please sign in to comment.