
[Relay][Keras] Dot #3668

Merged
merged 3 commits into apache:master from tensordot
Aug 28, 2019

Conversation

yongwww
Member

@yongwww yongwww commented Jul 30, 2019

@yongwww yongwww changed the title [Relay][Keras] Dot [WIP][Relay][Keras] Dot Jul 30, 2019
@yongwww yongwww changed the title [WIP][Relay][Keras] Dot [Relay][Keras] Dot Jul 30, 2019
@tqchen tqchen assigned Huyuwei and kazum and unassigned Huyuwei Aug 1, 2019
@tqchen
Member

tqchen commented Aug 1, 2019

@kazum @Huyuwei can you help manage the PR?

        inexpr[1] = _op.transpose(inexpr[1], axes=[0, 2, 1])
    else:
        raise tvm.error.OpAttributeUnimplemented(
            'Dot with {} is not supported.'.format(keras_layer.axes))
Contributor


How about the following?

  • axes=[axes,axes] when axes is an integer.
  • raise an error when axes[i] is neither 1 nor 2.
  • transpose inexpr[i] when axes[i] is 1.

This can also handle axes=[1,1] and raise an error against axes=0 or axes=3.
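The suggested handling could be sketched as a standalone helper (a hypothetical sketch, not the frontend's actual code; the name `normalize_dot_axes` is invented for illustration):

```python
# Hypothetical sketch of the suggested axes handling for the Keras Dot
# converter; the helper name and standalone form are assumptions.
def normalize_dot_axes(axes):
    """Normalize a Keras Dot `axes` attribute to a two-element list,
    rejecting axes outside {1, 2} for 3-D inputs."""
    if isinstance(axes, int):
        axes = [axes, axes]  # axes=2 becomes axes=[2, 2]
    for axis in axes:
        if axis not in (1, 2):  # rejects e.g. axes=0 and axes=3
            raise ValueError(
                'Dot with axes {} is not supported.'.format(axes))
    return list(axes)
```

After normalization, an input whose contraction axis is 1 would be transposed (e.g. with `_op.transpose(inexpr[i], axes=[0, 2, 1])`) before the batch matmul.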

Member Author


@kazum thanks, I like the idea!

    else:
        raise tvm.error.OpAttributeUnimplemented(
            'Dot with {} is not supported.'.format(keras_layer.axes))
    ret = _op.nn.batch_matmul(ret, inexpr[1])
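For context, relay's `nn.batch_matmul` multiplies its second operand transposed on the last two axes. A NumPy sketch of that convention (illustration only, not TVM code):

```python
import numpy as np

# Reference semantics of relay's nn.batch_matmul convention:
# (batch, M, K) x (batch, N, K) -> (batch, M, N), i.e. the second
# operand is contracted along its last axis.
def batch_matmul_ref(a, b):
    return np.einsum('bmk,bnk->bmn', a, b)

a = np.arange(24.0).reshape(2, 3, 4)
b = np.arange(40.0).reshape(2, 5, 4)
out = batch_matmul_ref(a, b)
assert out.shape == (2, 3, 5)
```

This convention is why the converter may need the `_op.transpose(inexpr[1], axes=[0, 2, 1])` seen earlier before calling `batch_matmul`.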
Contributor


This code works only when the input tensors are 3-dimensional. Please add a check.
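The requested check might look like the following (a hypothetical sketch; the helper name and how the input shapes are obtained are assumptions):

```python
# Hypothetical rank check before lowering Keras Dot to batch_matmul.
def check_dot_input_rank(shapes):
    """Reject inputs that batch_matmul cannot handle: it expects two
    rank-3 operands of shape (batch, M, K) and (batch, N, K)."""
    for shape in shapes:
        if len(shape) != 3:
            raise ValueError(
                'Dot is only supported for 3-D inputs, '
                'got rank {}.'.format(len(shape)))

check_dot_input_rank([(2, 3, 4), (2, 5, 4)])  # rank-3 operands pass
```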

@kazum kazum added status: need update need update based on feedbacks and removed status: need review labels Aug 6, 2019
@yongwww yongwww force-pushed the tensordot branch 2 times, most recently from a305140 to 05fcb90 Compare August 16, 2019 23:17
@yongwww
Member Author

yongwww commented Aug 16, 2019

@kazum just updated the PR, thanks a lot for your comment!

@@ -461,6 +479,8 @@ def _convert_concat(inexpr, keras_layer, _):

def _convert_reshape(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    if len(keras_layer.target_shape) < 3:
        return _op.reshape(inexpr, newshape=(1, ) + keras_layer.target_shape)
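A NumPy illustration of what the hunk above does: Keras `target_shape` describes the per-sample shape without the batch dimension, so the converter prepends a batch of 1 (the shapes here are made-up examples):

```python
import numpy as np

# Keras Reshape(target_shape=(32, 32)) gives the per-sample shape;
# the converter prepends the batch dimension explicitly.
target_shape = (32, 32)
data = np.zeros((1, 32, 32, 1))          # example batch-1 input
out = data.reshape((1,) + target_shape)  # mirrors newshape=(1, ) + target_shape
assert out.shape == (1, 32, 32)
```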
Contributor


Can you add a test where this change is necessary?

Member Author


@kazum test added

data_1 = keras.layers.Input(shape=(32, 32, 3))
data_2 = keras.layers.Input(shape=(32, 32))
x_1 = keras.layers.Reshape(target_shape=(32, 32, 3))(data_1)
x_2 = keras.layers.Reshape(target_shape=(32, 32))(data_2)
Contributor


This test cannot pass for shape (32, 3), can it?

It looks like changing reshape is unnecessary to support dot. How about removing the change from this PR?

Member Author


Yeah, it is not related to dot; it's a bug in reshape. Sure, let me remove it.

Member Author


@kazum updated

Contributor

@kazum kazum left a comment


Looks good, thanks @yongwww !

@kazum kazum merged commit 19b8b3a into apache:master Aug 28, 2019
@kazum kazum removed the status: need update need update based on feedbacks label Aug 28, 2019
wweic pushed a commit to wweic/tvm that referenced this pull request Sep 16, 2019
* [Relay][Keras] Dot

* fix reshape

* fix comments
wweic pushed a commit to neo-ai/tvm that referenced this pull request Sep 16, 2019
* [Relay][Keras] Dot

* fix reshape

* fix comments
@tqchen tqchen unassigned kazum and Huyuwei Nov 4, 2019