Skip to content

Commit

Permalink
[Relay][Keras] Dot
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Aug 16, 2019
1 parent b76b627 commit 5f94292
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
18 changes: 16 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,21 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
def _convert_merge(inexpr, keras_layer, _):
merge_type = type(keras_layer).__name__
ret = inexpr[0]
if merge_type == 'Subtract':
if merge_type == 'Dot':
if keras_layer.axes == 1:
ret = _op.transpose(ret, axes=[0, 2, 1])
inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])

elif isinstance(keras_layer.axes, list):
if keras_layer.axes == [1, 2]:
ret = _op.transpose(ret, axes=[0, 2, 1])
elif keras_layer.axes == [2, 1]:
inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])
else:
raise tvm.error.OpAttributeUnimplemented(
'Dot with {} is not supported.'.format(keras_layer.axes))
ret = _op.nn.batch_matmul(ret, inexpr[1])
elif merge_type == 'Subtract':
assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
ret = _op.subtract(ret, inexpr[1])
elif merge_type in ['Add', 'Multiply', 'Maximum']:
Expand Down Expand Up @@ -625,7 +639,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument

'Average' : _convert_merge,
'Maximum' : _convert_merge,
# 'Dot' : _convert_merge,
'Dot' : _convert_merge,
'Permute' : _convert_permute,
# 'Embedding' : _convert_embedding,
# 'RepeatVector' : _convert_repeat_vector,
Expand Down
15 changes: 14 additions & 1 deletion tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,24 @@ def test_forward_merge():
keras.layers.Average(),
keras.layers.Concatenate()]
for merge_func in merge_funcs:
if isinstance(merge_func, keras.layers.merge.Subtract):
if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)):
out = merge_func([x, y])
else:
out = merge_func([x, y, z])
keras_model = keras.models.Model(data, out)
verify_keras_frontend(keras_model)

def test_forward_merge_dot():
data1 = keras.layers.Input(shape=(2, 2))
data2 = keras.layers.Input(shape=(2, 2))
merge_funcs = [keras.layers.Dot(axes=[1, 2]),
keras.layers.Dot(axes=[2, 1]),
keras.layers.Dot(axes=1),
keras.layers.Dot(axes=2)]
for merge_func in merge_funcs:
out = merge_func([data1, data2])
keras_model = keras.models.Model([data1, data2], out)
verify_keras_frontend(keras_model, need_transpose=False)

def test_forward_activations():
data = keras.layers.Input(shape=(32, 32, 3))
Expand Down Expand Up @@ -276,7 +287,9 @@ def test_forward_mobilenet():


if __name__ == '__main__':

test_forward_merge()
test_forward_merge_dot()
test_forward_activations()
test_forward_dense()
test_forward_permute()
Expand Down

0 comments on commit 5f94292

Please sign in to comment.