[Frontend][ONNX] LSTM Support #4825
Conversation
@masahi, @soiferj, @mbrookhart can you take a look at this PR?
Awesome!! I was looking at implementing this myself yesterday. I'll take a look as soon as possible. Thanks for sending the PR!
Looks great! Will wait for @soiferj's review.
def _impl_v7(cls, inputs, attr, params):
    # Unpack inputs, note that if optional and not provided then value will be None.
    X = inputs[0]
    W = inputs[1]
Is there any case when the weights won't be constant? If they're constant, we can remove some operations from the graph and compute them here (like squeeze). By constant, I mean we can call infer_value on it.
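A minimal sketch of that idea, assuming TVM's infer_value helper from tvm.relay.frontend.common (the maybe_fold_squeeze name is hypothetical, not from the PR):

from tvm import relay
from tvm.relay.frontend.common import infer_value

def maybe_fold_squeeze(weight, params):
    # If the weight is a known constant, evaluate the squeeze at
    # conversion time and embed the result as a constant, removing
    # the op from the graph.
    try:
        value = infer_value(weight, params)  # raises if weight is not constant
        return relay.const(value.asnumpy().squeeze(0))
    except Exception:
        # Otherwise keep the squeeze as a graph operation.
        return relay.squeeze(weight, axis=[0])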
I think in almost all cases it'd be safe to assume the weights are constant. However, the fold constant pass in Relay will eliminate any operations on constant weights anyway. Since treating the weights as non-constant is slightly more flexible, I prefer it.
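For reference, the pass being referred to is Relay's standard constant-folding transform, typically run as:

from tvm import relay

# FoldConstant evaluates any subgraph whose inputs are all constants,
# so ops like squeeze applied to constant weights are computed once at
# compile time regardless of how the converter emits them.
mod = relay.transform.FoldConstant()(mod)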
Sounds good. Thanks a lot for the updates.
tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)


def test_lstm():
Can you also add a test where initial c and h states are set to something other than 0?
New tests for initial states and peephole weights are added. Glad you pointed this out since both those cases had some small bugs.
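As a hedged sketch, a test along those lines builds the ONNX node with explicit initial_h / initial_c (and P for peepholes) inputs; the hidden_size value below is illustrative:

from onnx import helper

# ONNX LSTM input order: X, W, R, B, sequence_lens, initial_h, initial_c, P.
# Naming initial_h / initial_c (and P) exercises the non-zero initial state
# and peephole paths; "" marks the omitted optional sequence_lens input.
lstm_node = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R", "B", "", "initial_h", "initial_c", "P"],
    outputs=["Y", "Y_h", "Y_c"],
    hidden_size=16,
)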
@@ -32,6 +32,55 @@
__all__ = ['from_onnx']


class onnx_input():
I really like this design - very sleek.
Looks great, thanks Josh!
Thanks @jwfromm @soiferj @mbrookhart
* Initial version working and passing tests.
* WIP on supporting other activations.
* add support for multiple activation functions in lstm
* All tests working and code cleaned up.
* Undo import swap to avoid conflict with masahi.
* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <[email protected]>
This PR adds LSTM support to the Relay ONNX frontend.

Besides adding the LSTM parser itself, we encountered an issue where arguments to some ONNX operations (like LSTM) are optional. The current method for passing arguments to converters is to pack them into a list, but because some arguments are optional, the position of each input is inconsistent. Ideally we would use a dictionary mapping input names to their values; however, changing all inputs to a dictionary would require updating every existing operator and would complicate direct ONNX-to-Relay conversions. Our workaround is to add the onnx_input class, which can be accessed as a list (as we did previously) or with dictionary-style lookup by input name.
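A minimal sketch of that access pattern (the internals and method names here are illustrative, not the PR's exact implementation):

class onnx_input():
    # Stores inputs in positional order while remembering their ONNX
    # names, so converters can index positionally or by name; absent
    # optional inputs read back as None instead of raising.
    def __init__(self):
        self.names = []
        self.values = {}

    def add(self, name, value):
        self.names.append(name)
        self.values[name] = value

    def __getitem__(self, key):
        if isinstance(key, int):
            if key >= len(self.names):
                return None  # optional input that was never provided
            return self.values[self.names[key]]
        return self.values.get(key)

# Usage: both styles resolve to the same value.
inputs = onnx_input()
inputs.add("X", "x_expr")
inputs.add("W", "w_expr")
assert inputs[0] == inputs["X"]
assert inputs[7] is None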