
Improve type handling in PyTorch frontend #5834

Merged: 3 commits merged into apache:master from pytorch_frontend_type_fix on Jun 22, 2020

Conversation

t-vi (Contributor) commented Jun 17, 2020

  • Use/check type information from the graph for inputs if available.
  • Allow the user to set a default dtype (defaulting to float32 for sanity and compatibility).
  • Implement type promotion to follow the PyTorch mechanism (see the sketch after this list). This includes fixing the handling of many "Scalar" overloads in PyTorch binary ops.
  • Use the dtype of the input for consts.
  • Fix arange/linspace type semantics.
  • Add support for traced functions (this is really about the handling of the "self" input).
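
For context, here is a minimal sketch (not taken from the PR) of the PyTorch promotion behavior the converter now mirrors; the dtypes and values are illustrative only.

    import torch

    a = torch.ones(2, dtype=torch.int64)
    b = torch.ones(2, dtype=torch.float32)

    # Mixing tensor dtypes promotes to the result type PyTorch computes:
    # int64 + float32 -> float32.
    print((a + b).dtype)            # torch.float32
    print(torch.result_type(a, b))  # torch.float32

    # A Python scalar participates with a weaker category and does not
    # widen the tensor dtype: float32 tensor + int scalar stays float32.
    print((b + 1).dtype)            # torch.float32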

masahi (Member) commented Jun 18, 2020

@siju-samuel please help review this PR.

siju-samuel (Member) commented:

@t-vi please rebase.

t-vi added 2 commits June 19, 2020 08:31

Improve type handling in PyTorch frontend

- Use type information from graph for inputs if available. Check
  against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and
  compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes
  fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the
  "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not
change the signature/requirements of from_pytorch.

Fix scalar detection using numpy.isscalar

and address other review comments. Thank you @siju-samuel
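
As a usage sketch (not part of the PR itself): with this change a caller can pass the new optional default_dtype keyword to relay.frontend.from_pytorch. The model, input name, and shape below are made up for illustration, and argument names other than default_dtype may differ between TVM versions.

    import torch
    import torchvision
    from tvm import relay

    # Trace an off-the-shelf model to get a TorchScript module.
    model = torchvision.models.resnet18().eval()
    inp = torch.randn(1, 3, 224, 224)
    script_module = torch.jit.trace(model, inp)

    # default_dtype is the new optional keyword; the rest of the call is unchanged.
    mod, params = relay.frontend.from_pytorch(
        script_module,
        [("input0", (1, 3, 224, 224))],
        default_dtype="float32",
    )
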
t-vi force-pushed the pytorch_frontend_type_fix branch from 94b3952 to 7e752f2 on June 19, 2020 06:31
t-vi (Contributor, Author) commented Jun 19, 2020

Rebased.

masahi (Member) left a comment

Looks like a good cleanup, thanks.

t-vi (Contributor, Author) commented Jun 19, 2020

And the quantization test fails by a small margin. :/ I'll look into it some more. I wonder whether a better strategy for the lost types would be to run the ops on the fly instead...

t-vi (Contributor, Author) commented Jun 19, 2020

I have to ask about the test criterion for qnn_test.py::test_serialized_modules.
It seems to pass here locally (with the PyTorch 1.5 CPU wheel).
It gets ~70% identical results, but I would think that a more natural test is

    num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2)

For me the outputs are around 0.5, so 1e-2 is roughly 2**-6, i.e. within 2x-4x of the quantization accuracy. With that criterion I get a 100% match.
This might not work well for complex models, but for this simple one it should work. What do you think?
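
For illustration only, the criterion described above could be wrapped like this; the helper name and the 0.999 match requirement are hypothetical and not necessarily what the merged commit uses, while the 1e-2 tolerance is the one from the comment.

    import numpy as np

    def assert_quantized_outputs_close(tvm_result, pt_result, tol=1e-2, min_match=0.999):
        # Count elements that agree within an absolute tolerance and require
        # (nearly) all of them to match.
        num_identical = np.sum(np.abs(tvm_result - pt_result) < tol)
        match_ratio = num_identical / float(tvm_result.size)
        assert match_ratio >= min_match, "only %.1f%% of elements match" % (100 * match_ratio)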

masahi (Member) commented Jun 19, 2020

That 0.2 threshold was chosen arbitrarily, without a good reason. If you come up with a principled criterion, that is great, as long as it won't get flaky.

Other quantized tests also need a good test criterion (at the moment the outputs are not tested at all, due to the difficulty of comparing TVM and PyTorch outputs for end-to-end quantized models).

t-vi (Contributor, Author) commented Jun 19, 2020

Looks like the CI likes the revised test criterion. 🙂

t-vi (Contributor, Author) commented Jun 19, 2020

@masahi @siju-samuel Thank you for your reviews so far. Anything else I can do to help this along?

t-vi requested review from siju-samuel and masahi on June 19, 2020 18:45
masahi (Member) commented Jun 19, 2020

@siju-samuel You can merge this PR with your approval. I'm blocked by your change request.

siju-samuel merged commit 4eb49f0 into apache:master on Jun 22, 2020
siju-samuel (Member) commented:

Thanks for the cleanup @t-vi @masahi. This PR is merged.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request on Jun 30, 2020
* Improve type handling in PyTorch frontend

- Use type information from graph for inputs if available. Check
  against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and
  compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes
  fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the
  "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not
change the signature/requirements of from_pytorch.

* Fix scalar detection using numpy.isscalar

and address other review comments. Thank you @siju-samuel

* refine test criterion on qnn_test::test_serialized_modules, fix bool conversion of const
zhiics pushed a commit to neo-ai/tvm that referenced this pull request on Jul 2, 2020 (same commit message as above)