Improve type handling in PyTorch frontend #5834
Conversation
t-vi commented Jun 17, 2020
- Use and check type information from the graph for inputs, if available.
- Allow the user to set a default dtype (defaults to float32 for sanity and compatibility); see the usage sketch below.
- Implement type promotion following the PyTorch mechanism. This includes fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Use the dtype of the input for consts.
- Fix arange/linspace type semantics.
- Add support for traced functions. (Because it really is about the "self" input handling.)
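A minimal usage sketch (mine, not code from the PR) of the new optional keyword, assuming the post-PR signature `relay.frontend.from_pytorch(script_module, input_shapes, custom_convert_map=None, default_dtype="float32")`:

```python
# Hypothetical example: convert a traced model, setting the dtype used
# for values whose type the TorchScript graph does not pin down.
import torch
import torchvision
from tvm import relay

model = torchvision.models.resnet18().eval()
inp = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, inp)

# default_dtype defaults to "float32" for sanity and compatibility.
mod, params = relay.frontend.from_pytorch(
    traced, [("input0", (1, 3, 224, 224))], default_dtype="float32"
)
```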
Force-pushed from 0d53153 to 509f520.
@siju-samuel please help review this PR.
@t-vi please rebase.
- Use type information from graph for inputs if available. Check against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not change the signature/requirements of from_pytorch.
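To illustrate the promotion rules the frontend now follows (my own example, not code from the PR): PyTorch promotes an integer tensor combined with a float "Scalar" to the default floating dtype, while an integer scalar leaves the tensor dtype unchanged.

```python
import torch

t = torch.arange(4)               # int64 tensor
print(torch.result_type(t, 2.5))  # torch.float32: a float scalar promotes
print((t + 2.5).dtype)            # torch.float32
print((t + 2).dtype)              # torch.int64: an int scalar does not promote
```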
Fix scalar detection using numpy.isscalar and address other review comments. Thank you @siju-samuel
Force-pushed from 94b3952 to 7e752f2.
Rebased.
Looks like a good cleanup, thanks.
And the quantization test fails by a small margin. :/ I'll look into it some more. I wonder whether a better strategy for the lost types would be to run the ops on the fly instead...
I have to ask about the test criterion for qnn_test::test_serialized_modules: for me the outputs are around 0.5, so a tolerance of 1e-2 is about 2**-6, i.e. 2x-4x of the quantization accuracy. With that I get a 100% match for the criterion.
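A back-of-the-envelope reading of those numbers (my own sketch; the [0, 1] output range and int8 quantization are assumptions, not stated in the thread):

```python
# For 8-bit quantization of outputs spanning roughly [0, 1], the step
# size is about 1/256, so an absolute tolerance of 1e-2 sits 2x-4x above it.
qstep = 1.0 / 256      # ~0.0039, assumed quantization step
atol = 1e-2            # proposed tolerance, roughly 2**-6.6
print(atol / qstep)    # ~2.56, i.e. within the quoted 2x-4x band
```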
refine test criterion on qnn_test::test_serialized_modules, fix bool conversion of const
That 0.2 threshold was chosen arbitrarily, without a good reason. If you come up with a principled criterion, that's great, as long as it won't get flaky. The other quantized tests also need good test criteria (at the moment their outputs are not tested at all, due to the difficulty of comparing TVM and PyTorch outputs for end-to-end qmodels).
Looks like the CI likes the revised test criterion. 🙂
@masahi @siju-samuel Thank you for your reviews so far. Anything else I can do to help this along?
@siju-samuel You can merge this PR with your approval. I'm blocked by your change request.
* Improve type handling in PyTorch frontend
  - Use type information from graph for inputs if available. Check against shape information from graph if available.
  - Allow user to set default dtype (default to float32 for sanity and compatibility).
  - Implement type promotion to follow PyTorch mechanism. This includes fixing the handling of many "Scalar" overloads in PyTorch binary ops.
  - Fix arange/linspace type semantics.
  - Added support for traced functions. (Because it really is about the "self" input handling.)
  Aside from adding an optional default_dtype keyword argument, this does not change the signature/requirements of from_pytorch.
* Fix scalar detection using numpy.isscalar and address other review comments. Thank you @siju-samuel
* refine test criterion on qnn_test::test_serialized_modules, fix bool conversion of const