[Frontend][ONNX] LSTM Support #4825

Merged (6 commits into apache:master, Feb 7, 2020)

Conversation

@jwfromm (Contributor) commented Feb 6, 2020:

This PR adds LSTM support to the Relay ONNX frontend.

Besides adding the LSTM parser itself, this PR addresses an issue with optional arguments: for some ONNX operations (like LSTM), inputs are optional, but the current method of passing arguments to converters simply packs them into a list, so the position of each input becomes inconsistent when optional inputs are omitted. Ideally we would use a dictionary mapping input names to their values; however, switching everything to a dictionary would require changing all current operators and would complicate direct ONNX-to-Relay conversions. The workaround here is a new onnx_input class that can be accessed either as a list, as before, or by input name with dictionary-style lookup.
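
For illustration, here is a minimal sketch of such a hybrid list/dict container (the class name matches the PR, but the body below is an assumption, not the exact code merged):

```python
# Sketch of a container that supports both positional (list-style) and
# name-based (dict-style) lookup of converter inputs. Optional ONNX inputs
# that were never provided simply come back as None.
class onnx_input:
    def __init__(self):
        self.input_keys = []   # ordered ONNX input names
        self.input_dict = {}   # name -> Relay expression

    def __getitem__(self, item):
        if isinstance(item, int):
            # Positional access, as existing converters expect.
            if item >= len(self.input_keys):
                return None
            return self.input_dict[self.input_keys[item]]
        if isinstance(item, str):
            # Name-based access for operators with optional inputs.
            return self.input_dict.get(item, None)
        raise TypeError("onnx_input indices must be int or str")

    def __setitem__(self, item, value):
        if not isinstance(item, str):
            raise TypeError("onnx_input keys must be strings")
        if item not in self.input_dict:
            self.input_keys.append(item)
        self.input_dict[item] = value

    def __len__(self):
        return len(self.input_keys)
```

With this, an LSTM converter can write `inputs["initial_h"]` and get None when that input was omitted, while existing converters can keep using `inputs[0]`.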

@jwfromm (Contributor, Author) commented Feb 6, 2020:

@masahi, @soiferj, @mbrookhart can you take a look at this PR?

@jwfromm jwfromm requested a review from masahi February 6, 2020 01:47
@masahi masahi self-assigned this Feb 6, 2020
@soiferj (Contributor) commented Feb 6, 2020:

Awesome!! I was looking at implementing this myself yesterday. I’ll take a look as soon as possible. Thanks for sending the PR!

@masahi (Member) commented Feb 6, 2020:

Looks great! Will wait for @soiferj's review.

```python
def _impl_v7(cls, inputs, attr, params):
    # Unpack inputs. Optional inputs that were not provided will be None.
    X = inputs[0]
    W = inputs[1]
```
@soiferj (Contributor) commented Feb 6, 2020:


Is there any case when the weights won’t be constant? If they’re constant, we can remove some operations from the graph and compute them here (like squeeze).

By constant, I mean we can call infer_value on it.

@jwfromm (Contributor, Author) replied:

I think in almost all cases it'd be safe to assume the weights are constant. However, Relay's fold-constant pass will eliminate all operations on the weights anyway, and since treating the weights as non-constant is slightly more flexible, I prefer it.
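
For context, a minimal sketch (not code from this PR) of the fold-constant behavior being relied on here; the weight shape and the squeeze are hypothetical stand-ins:

```python
import numpy as np
import tvm
from tvm import relay

# Hypothetical stand-in: a squeeze applied to a constant weight tensor, as an
# LSTM converter might emit when unpacking ONNX weight layouts.
w = relay.const(np.ones((1, 4, 8), dtype="float32"))
x = relay.var("x", shape=(1, 8), dtype="float32")
y = relay.nn.dense(x, relay.squeeze(w, axis=[0]))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# FoldConstant evaluates the squeeze on the constant weight at compile time,
# so the extra op vanishes from the final graph without eager infer_value.
mod = relay.transform.FoldConstant()(mod)
print(mod)
```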

@soiferj (Contributor) replied:

Sounds good. Thanks a lot for the updates.

```python
tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)


def test_lstm():
```
@soiferj (Contributor) commented Feb 6, 2020:


Can you also add a test where initial c and h states are set to something other than 0?

@jwfromm (Contributor, Author) replied:

New tests for initial states and peephole weights are added. Glad you pointed this out, since both of those cases had some small bugs.
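
For illustration, a sketch of how such a test graph might be constructed with nonzero initial states (shapes, names, and the skipped B/sequence_lens inputs are assumptions; this is not the exact test added in the PR):

```python
import numpy as np
from onnx import helper, TensorProto

seq_len, batch, input_size, hidden_size = 3, 1, 4, 8

# ONNX LSTM input order is X, W, R, B, sequence_lens, initial_h, initial_c, P;
# empty strings skip the optional B and sequence_lens inputs.
lstm_node = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R", "", "", "initial_h", "initial_c"],
    outputs=["Y", "Y_h", "Y_c"],
    hidden_size=hidden_size,
)

graph = helper.make_graph(
    [lstm_node],
    "lstm_initial_state_test",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, input_size]),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 4 * hidden_size, input_size]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, [1, 4 * hidden_size, hidden_size]),
        helper.make_tensor_value_info("initial_h", TensorProto.FLOAT, [1, batch, hidden_size]),
        helper.make_tensor_value_info("initial_c", TensorProto.FLOAT, [1, batch, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, 1, batch, hidden_size]),
        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [1, batch, hidden_size]),
        helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, [1, batch, hidden_size]),
    ],
)
model = helper.make_model(graph, producer_name="lstm_initial_state_test")

# Nonzero initial states exercise the code path this comment asked about;
# the model would then be run through the ONNX frontend and compared against
# a reference runtime with assert_allclose, as in the other LSTM tests.
h0 = np.random.uniform(size=(1, batch, hidden_size)).astype("float32")
c0 = np.random.uniform(size=(1, batch, hidden_size)).astype("float32")
```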

```
@@ -32,6 +32,55 @@
__all__ = ['from_onnx']


class onnx_input():
```
A Contributor commented:


I really like this design - very sleek.

@mbrookhart (Contributor) left a comment:


Looks Great, thanks Josh!


@masahi masahi merged commit 75e9f5d into apache:master Feb 7, 2020
@masahi (Member) commented Feb 7, 2020:

Thanks @jwfromm @soiferj @mbrookhart

anijain2305 pushed a commit to anijain2305/tvm that referenced this pull request Feb 10, 2020
* Initial version working and passing tests.

* WIP on supporting other activations.

* add support for multiple activation functions in lstm

* All tests working and code cleaned up.

* Undo import swap to avoid conflict with masahi.

* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <[email protected]>
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 26, 2020 (same commit messages as above)

alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 28, 2020 (same commit messages as above)

zhiics pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2020 (same commit messages as above)
@jwfromm jwfromm deleted the onnx_lstm branch April 12, 2023 15:55