[Relay][Keras] Dot #3668
Conversation
python/tvm/relay/frontend/keras.py
Outdated
                inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])
            else:
                raise tvm.error.OpAttributeUnimplemented(
                    'Dot with {} is not supported.'.format(keras_layer.axes))
How about the following?
- treat `axes` as `[axes, axes]` when `axes` is an integer.
- raise an error when `axes[i]` is neither 1 nor 2.
- transpose `inexpr[i]` when `axes[i]` is 1.

This can also handle `axes=[1,1]`, and raise an error against `axes=0` or `axes=3`.
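The suggested handling can be sketched as a small helper. This is an illustration of the review suggestion, not the code that landed in the PR; the helper name and the `ValueError` (in place of `tvm.error.OpAttributeUnimplemented`) are assumptions for a self-contained example.

```python
def normalize_dot_axes(axes):
    """Sketch of the suggested axes handling for the Keras Dot converter.

    Treats an integer `axes` as [axes, axes], rejects any axis that is
    neither 1 nor 2, and reports which inputs need a transpose (axis 1
    means the contraction axis is not the last one, so that input must
    be transposed before batch_matmul).
    """
    if isinstance(axes, int):
        axes = [axes, axes]
    need_transpose = []
    for axis in axes:
        if axis not in (1, 2):
            # The real converter raises tvm.error.OpAttributeUnimplemented here.
            raise ValueError('Dot with {} is not supported.'.format(axes))
        need_transpose.append(axis == 1)
    return need_transpose
```

With this shape, `axes=[1,1]` is accepted (both inputs transposed) while `axes=0` or `axes=3` raises, as the reviewer intended.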
@kazum thanks, I like the idea!
python/tvm/relay/frontend/keras.py
Outdated
        else:
            raise tvm.error.OpAttributeUnimplemented(
                'Dot with {} is not supported.'.format(keras_layer.axes))
        ret = _op.nn.batch_matmul(ret, inexpr[1])
This code works only when the dimension of the input tensor is 3. Please add the check.
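The requested check could look like the following sketch. The helper name is illustrative, and `ValueError` stands in for `tvm.error.OpAttributeUnimplemented` so the example is self-contained.

```python
def check_dot_input_rank(shape_a, shape_b):
    """Sketch of the rank check requested above.

    Relay batch_matmul operates on rank-3 tensors (batch, rows, cols),
    so the Dot converter should reject any other input rank up front.
    """
    for shape in (shape_a, shape_b):
        if len(shape) != 3:
            # The real converter would raise tvm.error.OpAttributeUnimplemented.
            raise ValueError(
                'Dot with input of rank {} is not supported.'.format(len(shape)))
```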
Force-pushed from a305140 to 05fcb90
@kazum just updated the pr, thanks a lot for your comment!
python/tvm/relay/frontend/keras.py
Outdated
@@ -461,6 +479,8 @@ def _convert_concat(inexpr, keras_layer, _):


def _convert_reshape(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    if len(keras_layer.target_shape) < 3:
        return _op.reshape(inexpr, newshape=(1, ) + keras_layer.target_shape)
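The intent of this hunk (later dropped from the PR) can be isolated as a small shape helper. This is a sketch of the newshape computation only, with an illustrative function name; `target_shape` is the Keras layer attribute from the diff.

```python
def relay_reshape_newshape(target_shape):
    """Sketch of the newshape logic in the diff above.

    When the Keras target_shape has fewer than 3 dims, prepend the
    implicit batch dim of 1 so Relay's reshape sees the full shape;
    higher-rank shapes fall through to the existing handling.
    """
    if len(target_shape) < 3:
        return (1,) + tuple(target_shape)
    return tuple(target_shape)
```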
Can you add a test where this change is necessary?
@kazum test added
    data_1 = keras.layers.Input(shape=(32, 32, 3))
    data_2 = keras.layers.Input(shape=(32, 32))
    x_1 = keras.layers.Reshape(target_shape=(32, 32, 3))(data_1)
    x_2 = keras.layers.Reshape(target_shape=(32, 32))(data_2)
This test cannot pass for shape (32, 3), can it?
It looks like changing reshape is unnecessary to support dot. How about removing the change from this pr?
Yeah, it is not related to dot; it's a bug in reshape. Sure, let me remove it.
@kazum updated
Looks good, thanks @yongwww !
* [Relay][Keras] Dot * fix reshape * fix comments
@Huyuwei