Skip to content

Commit

Permalink
[NNVM][TENSORFLOW] Some cleanup by combining depthwise with convolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Aug 3, 2018
1 parent 7ee9cca commit 6d68a18
Showing 1 changed file with 21 additions and 83 deletions.
104 changes: 21 additions & 83 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
custom_check=_dimension_constraint())(inputs, attr)
return _impl

def _conv():
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")

# Extract kernel shape from params
conv_param_weights = params[inputs[1].list_output_names()[0]]

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
attr['channels'] = conv_param_weights.shape[3]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
attr['channels'] = conv_param_weights.shape[1]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")

if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]

pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]

else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'

return AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
return _impl

def _depthwise_conv():
def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
input_shapes = attr['_input_shapes'][inputs[0]]
Expand All @@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
attr['channels'] = input_shapes[0][3] * depth_mult
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[3]
else:
attr['channels'] = input_shapes[0][3] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
attr['channels'] = input_shapes[0][1] * depth_mult
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[1]
else:
attr['channels'] = input_shapes[0][1] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))


if opname == 'depthwise':
attr['groups'] = attr['channels']

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix groups
attr['groups'] = attr['channels']

# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")

Expand Down Expand Up @@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
if opname == 'conv':
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

return AttrCvt(
op_name=_dimension_picker('conv'),
Expand Down Expand Up @@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
'CheckNumerics' : _check_numerics(),
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv(),
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'ExpandDims' : _expand_dims(),
'Identity' : _identity(),
Expand All @@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _depthwise_conv(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
Expand Down

0 comments on commit 6d68a18

Please sign in to comment.