diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 8be3d221d42d..635a600e7157 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -156,7 +156,28 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
 def _convert_merge(inexpr, keras_layer, _):
     merge_type = type(keras_layer).__name__
     ret = inexpr[0]
-    if merge_type == 'Subtract':
+    if merge_type == 'Dot':
+        # Normalize an integer `axes` to a two-element list, one entry per input.
+        axes = keras_layer.axes
+        if isinstance(keras_layer.axes, int):
+            axes = [keras_layer.axes, keras_layer.axes]
+        if isinstance(axes, (list, tuple)):
+            if len(axes) != 2:
+                raise tvm.error.OpAttributeUnimplemented(
+                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
+            for i, axis in enumerate(axes):
+                # Only axes 1 and 2 are accepted; the transposes below use 3-element permutations.
+                if axis not in [1, 2]:
+                    raise tvm.error.OpAttributeUnimplemented(
+                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
+                if axis == 2:
+                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
+        else:
+            raise tvm.error.OpAttributeUnimplemented(
+                'Dot with axes {} is not supported.'.format(keras_layer.axes))
+        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
+        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
+    elif merge_type == 'Subtract':
         assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
         ret = _op.subtract(ret, inexpr[1])
     elif merge_type in ['Add', 'Multiply', 'Maximum']:
@@ -635,7 +656,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
 
     'Average' : _convert_merge,
     'Maximum' : _convert_merge,
-    # 'Dot' : _convert_merge,
+    'Dot' : _convert_merge,
     'Permute' : _convert_permute,
     # 'Embedding' : _convert_embedding,
     # 'RepeatVector' : _convert_repeat_vector,
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 9996bb77f168..4b71cb6f9a27 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -84,13 +84,27 @@ def test_forward_merge():
                    keras.layers.Average(),
                    keras.layers.Concatenate()]
     for merge_func in merge_funcs:
-        if isinstance(merge_func, keras.layers.merge.Subtract):
+        if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)):
             out = merge_func([x, y])
         else:
             out = merge_func([x, y, z])
         keras_model = keras.models.Model(data, out)
         verify_keras_frontend(keras_model)
 
+def test_forward_merge_dot():
+    """Convert Dot merge layers built with list and integer `axes` on two (2, 2) inputs."""
+    data1 = keras.layers.Input(shape=(2, 2))
+    data2 = keras.layers.Input(shape=(2, 2))
+    merge_funcs = [keras.layers.Dot(axes=[1, 2]),
+                   keras.layers.Dot(axes=[2, 1]),
+                   keras.layers.Dot(axes=[1, 1]),
+                   keras.layers.Dot(axes=[2, 2]),
+                   keras.layers.Dot(axes=1),
+                   keras.layers.Dot(axes=2)]
+    for merge_func in merge_funcs:
+        out = merge_func([data1, data2])
+        keras_model = keras.models.Model([data1, data2], out)
+        verify_keras_frontend(keras_model)
 
 def test_forward_activations():
     data = keras.layers.Input(shape=(32, 32, 3))
@@ -281,6 +295,7 @@ def test_forward_mobilenet():
 
 if __name__ == '__main__':
     test_forward_merge()
+    test_forward_merge_dot()
     test_forward_activations()
     test_forward_dense()
     test_forward_permute()