Skip to content

Commit

Permalink
fix reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Aug 17, 2019
1 parent 5f94292 commit 2fc9764
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,23 @@ def _convert_merge(inexpr, keras_layer, _):
merge_type = type(keras_layer).__name__
ret = inexpr[0]
if merge_type == 'Dot':
if keras_layer.axes == 1:
ret = _op.transpose(ret, axes=[0, 2, 1])
inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])

elif isinstance(keras_layer.axes, list):
if keras_layer.axes == [1, 2]:
ret = _op.transpose(ret, axes=[0, 2, 1])
elif keras_layer.axes == [2, 1]:
inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])
else:
axes = keras_layer.axes
if isinstance(keras_layer.axes, int):
axes = [keras_layer.axes, keras_layer.axes]
if isinstance(axes, list):
if len(axes) != 2:
raise tvm.error.OpAttributeUnimplemented(
'Dot with {} is not supported.'.format(keras_layer.axes))
ret = _op.nn.batch_matmul(ret, inexpr[1])
'Dot with axes {} is not supported.'.format(keras_layer.axes))
for i, axis in enumerate(axes):
if axis not in [1, 2]:
raise tvm.error.OpAttributeUnimplemented(
'Dot with axes {} is not supported.'.format(keras_layer.axes))
if axes[i] == 1:
inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
else:
raise tvm.error.OpAttributeUnImplemented(
'Dot with axes {} is not supported.'.format(keras_layer.axes))
ret = _op.nn.batch_matmul(inexpr[0], inexpr[1])
elif merge_type == 'Subtract':
assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
ret = _op.subtract(ret, inexpr[1])
Expand Down Expand Up @@ -475,6 +479,8 @@ def _convert_concat(inexpr, keras_layer, _):

def _convert_reshape(inexpr, keras_layer, _):
_check_data_format(keras_layer)
if len(keras_layer.target_shape) < 3:
return _op.reshape(inexpr, newshape=(1, ) + keras_layer.target_shape)
ch = keras_layer.input_shape[-1]
assert ch == keras_layer.target_shape[-1], \
"Only supports last dimension in target shape being equal to " \
Expand Down

0 comments on commit 2fc9764

Please sign in to comment.