diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 26e59dc7e830..c8db662152e9 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -215,7 +215,7 @@ def _impl(inputs, attr, params): attr['channels'] = input_shape[3] * depth_mult if 'dilations' in attr: - attr['dilations'] = (attr['dilations'][0], attr['dilations'][1]) + attr['dilations'] = (attr['dilations'][1], attr['dilations'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2]) elif attr['data_format'] == 'NCHW': depth_mult, _, kernel_h, kernel_w = weights_shape @@ -252,8 +252,12 @@ def _impl(inputs, attr, params): in_h = input_shape[2] in_w = input_shape[3] - pad_v = _get_pad_pair(in_h, kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + dilation_h = attr['dilations'][0] + dilation_w = attr['dilations'][1] + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) if attr['data_format'] == 'NHWC': inputs[0] = _sym.pad(data=inputs[0], @@ -783,6 +787,15 @@ def _impl(inputs, attr, params): )(inputs, attr) return _impl +def _split(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[0].list_output_names()[0]) + return AttrCvt( + op_name="split", ignores=['T'], + transforms={'num_split': 'indices_or_sections'}, + extras={'axis': axis.asnumpy()[0]})(inputs[1], attr) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -813,6 +826,7 @@ def _impl(inputs, attr, params): 'Add' : _elemwise('add'), 'Sub' : _elemwise('sub'), 'Mul' : _elemwise('mul'), + 'RealDiv' : _elemwise('div'), 'Maximum' : _elemwise('max'), 'Minimum' : _elemwise('min'), 'Sum' : _sum(), @@ -849,6 +863,7 @@ def _impl(inputs, attr, params): 'GreaterEqual' : _broadcast('greater_equal'), 'Equal' : _broadcast('equal'), 'NotEqual' : _broadcast('not_equal'), + 'Split' : _split(), } # _convert_map_rnn defines maps of rnn operator name to @@ -1144,21 +1159,26 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Pass the target layout attr["_target_layout"] = layout - #ToDo: Some of the tensorflow operators internaly maintain - #execution layers and its output name will the layer number along with - #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the - #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, - #the digit has to be ignored. - if ":" in node.input[0]: - in_name, _ = node.input[0].split(':') - node.input[0] = in_name - # Fill shapes for all inputs in a list inputs = [] for i in node.input: - if i in self._nodes: - inputs.append(self._nodes[i]) - input_shapes[self._nodes[i]] = self._output_shapes[i] + #ToDo: Some of the tensorflow operators internaly maintain + #execution layers and its output name will the layer number along with + #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the + #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, + #the digit has to be ignored. + tensor_name = i.split(':') + node_name = tensor_name[0] + if node_name in self._nodes: + in_sym = self._nodes[node_name] + if len(in_sym.list_output_names()) > 1: + tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0 + in_sym = in_sym[tensor_slot] + input_shape = (self._output_shapes[node_name])[tensor_slot] + else: + input_shape = self._output_shapes[node_name][0] + inputs.append(in_sym) + input_shapes[in_sym] = [input_shape] attr['_input_shapes'] = input_shapes inputs = self._fix_extranodes(node.op, attr, inputs) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index c98748c0fc03..219ceb5bd379 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -502,6 +502,83 @@ def test_forward_gather(): _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') +####################################################################### +# Split +# ----- + +def _test_split(in_shape, axis, num_split, dtype): + """ One iteration of a Split """ + + with tf.Graph().as_default(): + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.split(in_data, num_split, axis) + np_data = np.random.uniform(size=in_shape).astype(dtype) + compare_tf_with_tvm(np_data, 'in_data:0', 'split:0') + +def test_forward_split(): + '''test split layer''' + # rank 1 + _test_split((3,), 0, 1, 'float32') + _test_split((3,), 0, 3, 'float32') + _test_split((6,), 0, 3, 'float32') + # rank 2 + _test_split((6, 2), 0, 3, 'float32') + _test_split((2, 6), 1, 3, 'float32') + # rank 3 + _test_split((6, 2, 4), 0, 3, 'float32') + _test_split((2, 6, 4), 1, 3, 'float32') + _test_split((2, 4, 6), 2, 3, 'float32') + # rank 4 + _test_split((6, 1, 3, 5), 0, 3, 'float32') + _test_split((1, 6, 3, 5), 1, 3, 'float32') + _test_split((1, 3, 6, 5), 2, 3, 'float32') + _test_split((1, 3, 5, 6), 3, 3, 'float32') + # split along negative axis + _test_split((6, 1, 3, 5), -4, 3, 'float32') + _test_split((1, 6, 3, 5), -3, 3, 'float32') + _test_split((1, 3, 6, 5), -2, 3, 'float32') + _test_split((1, 3, 5, 6), -1, 3, 'float32') + + +####################################################################### +# Split followed by concat +# ------------------------ + +def _test_split_concat(in_shape, axis, num_split, dtype): + """ One iteration of a split_concat pair""" + + with tf.Graph().as_default(): + in_data = tf.placeholder(dtype, in_shape, name="in_data") + splitted = tf.split(in_data, num_split, axis) + tf.concat(splitted, axis) + np_data = np.random.uniform(size=in_shape).astype(dtype) + compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0') + +def test_forward_split_concat(): + '''test split followed by concat layers''' + # rank 1 + _test_split_concat((3,), 0, 1, 'float32') + _test_split_concat((3,), 0, 3, 'float32') + _test_split_concat((6,), 0, 3, 'float32') + # rank 2 + _test_split_concat((6, 2), 0, 3, 'float32') + _test_split_concat((2, 6), 1, 3, 'float32') + # rank 3 + _test_split_concat((6, 2, 4), 0, 3, 'float32') + _test_split_concat((2, 6, 4), 1, 3, 'float32') + _test_split_concat((2, 4, 6), 2, 3, 'float32') + # rank 4 + _test_split((6, 1, 3, 5), 0, 3, 'float32') + _test_split((1, 6, 3, 5), 1, 3, 'float32') + _test_split((1, 3, 6, 5), 2, 3, 'float32') + _test_split((1, 3, 5, 6), 3, 3, 'float32') + # split along negative axis + _test_split((6, 1, 3, 5), -4, 3, 'float32') + _test_split((1, 6, 3, 5), -3, 3, 'float32') + _test_split((1, 3, 6, 5), -2, 3, 'float32') + _test_split((1, 3, 5, 6), -1, 3, 'float32') + + ####################################################################### # Multi Input to graph # -------------------- @@ -1061,6 +1138,8 @@ def test_forward_rel_ops(): test_forward_pad() test_forward_gather() test_forward_stridedslice() + test_forward_split() + test_forward_split_concat() # Activations test_forward_sigmoid()