[Relay][Frontend][Keras] NHWC import support. #4899

Merged · 5 commits · Feb 18, 2020
python/tvm/relay/frontend/keras.py: 151 changes (113 additions, 38 deletions)
@@ -186,7 +186,7 @@ def _convert_merge(inexpr, keras_layer, _):
assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
ret = _op.subtract(ret, inexpr[1])
elif merge_type in ['Add', 'Multiply', 'Maximum']:
op_map = {'Add':_op.add, 'Multiply':_op.multiply, 'Maximum':_op.maximum}
op_map = {'Add': _op.add, 'Multiply': _op.multiply, 'Maximum': _op.maximum}
for i in range(1, len(inexpr)):
ret = op_map[merge_type](ret, inexpr[i])
elif merge_type == 'Average':
@@ -206,7 +206,7 @@ def _convert_permute(inexpr, keras_layer, _):
def _convert_dense(inexpr, keras_layer, etab):
weightList = keras_layer.get_weights()
weight = etab.new_const(weightList[0].transpose([1, 0]))
params = {'weight':weight, 'units':weightList[0].shape[1]}
params = {'weight': weight, 'units': weightList[0].shape[1]}
input_shape = keras_layer.input_shape
input_dim = len(input_shape)
# In case of RNN dense, input shape will be (1, 1, n)
@@ -234,18 +234,29 @@ def _convert_dense(inexpr, keras_layer, etab):

def _convert_convolution(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
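# Keras keeps conv kernels in channels-last order; for NCHW data they are
# transposed into OIHW below, while NHWC consumes them unchanged.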
if etab.data_layout == 'NHWC':
kernel_layout = 'HWIO'
else:
kernel_layout = 'OIHW'
is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
weightList = keras_layer.get_weights()
weight = weightList[0]
if is_deconv:
kernel_h, kernel_w, n_filters, in_channels = weightList[0].shape
weight = weightList[0].transpose([3, 2, 0, 1])
kernel_h, kernel_w, n_filters, in_channels = weight.shape
if kernel_layout == 'OIHW':
weight = weight.transpose([3, 2, 0, 1])
elif is_depthconv:
kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
weight = weightList[0].transpose([2, 3, 0, 1])
kernel_h, kernel_w, in_channels, depth_mult = weight.shape
if kernel_layout == 'OIHW':
weight = weight.transpose([2, 3, 0, 1])
else:
kernel_layout = "HWOI"
elif etab.data_layout == 'NCHW':
kernel_h, kernel_w, in_channels, n_filters = weight.shape
weight = weight.transpose([3, 2, 0, 1])
else:
kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
weight = weightList[0].transpose([3, 2, 0, 1])
kernel_h, kernel_w, in_channels, n_filters = weight.shape
if isinstance(keras_layer.dilation_rate, (list, tuple)):
dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
else:
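As an aside, a minimal numpy sketch of the kernel transpose used above (the shapes are arbitrary example values, not taken from the diff):

import numpy as np

# A Keras Conv2D kernel is stored as (kernel_h, kernel_w, in_channels, n_filters), i.e. HWIO.
hwio = np.zeros((3, 3, 16, 32))

# For NCHW data Relay expects OIHW, hence the transpose([3, 2, 0, 1]) above.
oihw = hwio.transpose([3, 2, 0, 1])
assert oihw.shape == (32, 16, 3, 3)

# For NHWC data the kernel is passed through unchanged as HWIO.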
@@ -257,7 +268,9 @@ def _convert_convolution(inexpr, keras_layer, etab):
'kernel_size': [kernel_h, kernel_w],
'strides': [stride_h, stride_w],
'dilation': dilation,
'padding': [0, 0]}
'padding': [0, 0],
'data_layout': etab.data_layout,
'kernel_layout': kernel_layout}
if is_depthconv:
params['channels'] = in_channels * depth_mult
params['groups'] = in_channels
@@ -274,8 +287,13 @@ def _convert_convolution(inexpr, keras_layer, etab):
if pad_t == pad_b and pad_l == pad_r:
params['padding'] = (pad_t, pad_l)
else:
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
if etab.data_layout == 'NCHW':
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
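# NHWC input: height and width live on axes 1 and 2, so the explicit padding goes there.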
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (pad_t, pad_b), (pad_l, pad_r), (0, 0)))

else:
msg = 'Padding with {} is not supported for operator Convolution ' \
'in frontend Keras.'
@@ -284,9 +302,13 @@ def _convert_convolution(inexpr, keras_layer, etab):
out = _op.nn.conv2d_transpose(data=inexpr, **params)
else:
out = _op.nn.conv2d(data=inexpr, **params)

if keras_layer.use_bias:
bias = etab.new_const(weightList[1])
out = _op.nn.bias_add(out, bias)
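# bias_add defaults to axis 1 (the NCHW channel axis); NHWC puts channels last, hence axis=-1.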
if etab.data_layout == 'NCHW':
out = _op.nn.bias_add(out, bias)
else:
out = _op.nn.bias_add(out, bias, axis=-1)
# defuse activation
if sys.version_info.major < 3:
act_type = keras_layer.activation.func_name
@@ -299,18 +321,27 @@ def _convert_convolution(inexpr, keras_layer, etab):

def _convert_separable_convolution(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
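# Keras depthwise kernels are stored as (kernel_h, kernel_w, in_channels, depth_multiplier);
# NHWC consumes them directly as HWOI, while NCHW needs the OIHW transpose below.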
if etab.data_layout == 'NHWC':
kernel_layout = 'HWOI'
else:
kernel_layout = 'OIHW'
weightList = keras_layer.get_weights()
# depthwise conv
kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
stride_h, stride_w = keras_layer.strides
weight0 = weightList[0].transpose([2, 3, 0, 1])
if kernel_layout == 'OIHW':
weight0 = weightList[0].transpose([2, 3, 0, 1])
else:
weight0 = weightList[0]
params0 = {'weight': etab.new_const(weight0),
'channels': in_channels * depth_mult,
'groups': in_channels,
'kernel_size': [kernel_h, kernel_w],
'strides': [stride_h, stride_w],
'dilation': [1, 1],
'padding': [0, 0]}
'padding': [0, 0],
'data_layout': etab.data_layout,
'kernel_layout': kernel_layout}
if keras_layer.padding == 'valid':
pass
# we insert a separate pad operator
@@ -322,26 +353,39 @@ def _convert_separable_convolution(inexpr, keras_layer, etab):
if pad_t == pad_b and pad_l == pad_r:
params0['padding'] = (pad_t, pad_l)
else:
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
if etab.data_layout == 'NCHW':
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (pad_t, pad_b), (pad_l, pad_r), (0, 0)))

else:
msg = 'Padding with {} is not supported for operator Separable ' \
'Convolution in frontend Keras.'
raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))

depthconv = _op.nn.conv2d(data=inexpr, **params0)
# pointwise conv
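# The 1x1 pointwise kernel is stored as (1, 1, in_channels * depth_multiplier, out_channels);
# NHWC uses it directly as HWIO, NCHW transposes it to OIHW.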
weight1 = weightList[1].transpose([3, 2, 0, 1])
if kernel_layout == 'OIHW':
weight1 = weightList[1].transpose([3, 2, 0, 1])
else:
weight1 = weightList[1]
kernel_layout = "HWIO"
params1 = {'weight': etab.new_const(weight1),
'channels': weight1.shape[0],
'channels': weightList[1].shape[3],
'groups': 1,
'kernel_size': [1, 1],
'strides': [1, 1],
'dilation': [1, 1]}
'dilation': [1, 1],
'data_layout': etab.data_layout,
'kernel_layout': kernel_layout}
out = _op.nn.conv2d(data=depthconv, **params1)
if keras_layer.use_bias:
bias = etab.new_const(weightList[2])
out = _op.nn.bias_add(out, bias)
if etab.data_layout == 'NCHW':
out = _op.nn.bias_add(out, bias)
else:
out = _op.nn.bias_add(out, bias, axis=-1)
# defuse activation
if sys.version_info.major < 3:
act_type = keras_layer.activation.func_name
@@ -352,26 +396,31 @@ def _convert_separable_convolution(inexpr, keras_layer, etab):
return out


def _convert_flatten(inexpr, keras_layer, _):
def _convert_flatten(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
# NCHW -> NHWC so that dense can be correctly converted
inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
if etab.data_layout == 'NCHW':
inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
return _op.nn.batch_flatten(inexpr)


def _convert_pooling(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
pool_type = type(keras_layer).__name__
# global pool in keras = global pool + flatten in relay
global_pool_params = {'layout': etab.data_layout}
if pool_type == 'GlobalMaxPooling2D':
return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
return _convert_flatten(
_op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab)
if pool_type == 'GlobalAveragePooling2D':
return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
return _convert_flatten(
_op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {'pool_size': [pool_h, pool_w],
'strides': [stride_h, stride_w],
'padding': [0, 0]}
'padding': [0, 0],
'layout': etab.data_layout}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
@@ -392,7 +441,7 @@ def _convert_pooling(inexpr, keras_layer, etab):
'Operator {} is not supported for frontend Keras.'.format(keras_layer))


def _convert_upsample(inexpr, keras_layer, _):
def _convert_upsample(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
upsample_type = type(keras_layer).__name__
params = {}
@@ -424,7 +473,9 @@ def _convert_upsample(inexpr, keras_layer, _):
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type))
return _op.nn.upsampling(inexpr, **params)
params['layout'] = etab.data_layout
out = _op.nn.upsampling(inexpr, **params)
return out


def _convert_cropping(inexpr, keras_layer, _):
@@ -442,9 +493,15 @@ def _convert_cropping(inexpr, keras_layer, _):


def _convert_batchnorm(inexpr, keras_layer, etab):
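# batch_norm normalizes over the channel axis: 1 for NCHW (and inputs with rank < 4), 3 for 4-D NHWC inputs.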
if etab.data_layout == 'NCHW' or len(keras_layer.input_shape) < 4:
axis = 1
else:
axis = 3

params = {'scale': False,
'center': False,
'epsilon': keras_layer.epsilon}
'epsilon': keras_layer.epsilon,
'axis': axis}
idx = 0
if keras_layer.scale:
params['scale'] = True
@@ -469,7 +526,7 @@ def _convert_batchnorm(inexpr, keras_layer, etab):
return result


def _convert_padding(inexpr, keras_layer, _):
def _convert_padding(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
padding_type = type(keras_layer).__name__
padding = keras_layer.padding
@@ -495,16 +552,21 @@ def _convert_padding(inexpr, keras_layer, etab):
else:
msg = 'Operator {} is not supported in frontend Keras.'
raise tvm.error.OpNotImplemented(msg.format(padding_type))
return _op.nn.pad(data=inexpr,
pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
if etab.data_layout == 'NCHW':
return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
return _op.nn.pad(data=inexpr, pad_width=((0, 0), (top, bottom), (left, right), (0, 0)))


def _convert_concat(inexpr, keras_layer, _):
def _convert_concat(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
return _op.concatenate(_as_list(inexpr), axis=1)
if etab.data_layout == 'NHWC' or len(keras_layer.input_shape[0]) < 4:
axis = -1
else:
axis = 1
return _op.concatenate(_as_list(inexpr), axis=axis)


def _convert_reshape(inexpr, keras_layer, _):
def _convert_reshape(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
inshape = keras_layer.input_shape # includes batch
tshape = keras_layer.target_shape # no batch
@@ -525,7 +587,10 @@ def _convert_reshape(inexpr, keras_layer, _):
assert ch == tshape[-1], \
"Only supports last dimension in target shape being equal to " \
"the channel number of input tensor."
shape = (-1, ch) + tshape[:-1]
if etab.data_layout == 'NCHW':
shape = (-1, ch) + tshape[:-1]
else:
shape = (-1,) + tshape[:-1] + (ch,)
return _op.reshape(inexpr, newshape=shape)


@@ -740,7 +805,7 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab):
etab.set_expr(name, out)


def from_keras(model, shape=None):
def from_keras(model, shape=None, layout='NCHW'):
"""Convert keras model to relay Function.

Parameters
@@ -751,6 +816,10 @@ def from_keras(model, shape=None):
shape: dict of str to int list/tuple
Input shapes of the model, optional

layout: str
One of 'NCHW' or 'NHWC', indicating how data should be arranged in
the output model.

Returns
-------
mod : tvm.IRModule
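As a usage sketch of the new argument (the model, input name, and shapes below are hypothetical; the call mirrors the signature added here, and from_keras is assumed to return the module together with its parameters):

import keras
from tvm import relay

# A toy channels-last Keras model; any 4-D-input model behaves the same way.
model = keras.models.Sequential([
    keras.layers.Conv2D(8, (3, 3), padding='same', input_shape=(32, 32, 3)),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation='softmax'),
])

# With layout='NHWC' the imported Relay graph keeps the channels-last layout,
# so the input shape is given in NHWC order as well (an assumption of this sketch).
shape_dict = {model.input_names[0]: (1, 32, 32, 3)}
mod, params = relay.frontend.from_keras(model, shape=shape_dict, layout='NHWC')
print(mod)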
@@ -793,6 +862,9 @@ def _convert_input_layer(keras_layer):
assert isinstance(model, expected_model_class)

etab = ExprTable()
# Set global data format.
assert layout in ['NCHW', 'NHWC'], "Layout must be one of 'NCHW' or 'NHWC'"
etab.data_layout = layout
for keras_layer in model.layers:
if isinstance(keras_layer, input_layer_class):
_convert_input_layer(keras_layer)
@@ -818,7 +890,10 @@ def _convert_input_layer(keras_layer):
# The one exception is InputLayer. Changing input variable names after conversion
# would confuse users, so we should keep them as far as possible. Fortunately,
# they are named uniquely to input_1, input_2, input_3... by default.
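# _as_list guards against these node attributes coming back as single objects rather than lists.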
zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
zip_node = zip(
_as_list(node.node_indices),
_as_list(node.tensor_indices),
_as_list(node.inbound_layers))
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, input_layer_class):
expr_name = inbound_layer.name