[Frontend][PyTorch] Cast to Long gets translated into cast to float32 #6300
Thanks, the cast to long was missing; #6301 should fix this.
Thanks @masahi for the super quick PR! However, the problem still persists. I think there is something else going on, because now it casts to long and then to float32 afterwards!
Perhaps the add only works over floats and is forcing a cast? No idea.
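A minimal sketch of the promotion behavior suspected here (plain PyTorch, no TVM involved): adding a long tensor to a float tensor promotes the result to float32, so the long dtype does not survive the add.

```python
import torch

# Adding an integer (long) tensor to a float tensor silently promotes
# the result to float32, so a later consumer expecting int64 sees float32.
idx = torch.tensor([0, 1, 2], dtype=torch.int64)
offset = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32)

result = idx + offset
print(result.dtype)  # torch.float32 -- the long dtype does not survive the add
```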
One thing that I noticed is that I get a bunch of
Unfortunately no. This is indeed an annoying issue. Even if you are tracing, we get this warning because it seems (?) the weight and bias parameters from torch are not typed. So most likely the warning comes from tensors corresponding to the parameters of conv, dense, batch norm, etc.
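A sketch of how to look at those parameter types, assuming a small stand-in conv module: trace it and print the TorchScript graph, then check the types annotated on the parameter values.

```python
import torch

# Stand-in conv module; the real case would be the Hummingbird model.
# In affected PyTorch versions the weight/bias parameters show up with a
# bare "Tensor" type in the graph, which is the likely warning source.
model = torch.nn.Conv2d(3, 8, kernel_size=3)
traced = torch.jit.trace(model, torch.randn(1, 3, 16, 16))
print(traced.graph)  # inspect the types annotated on the parameter values
```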
Ok I found the solution, but I think I am hitting some bug. So in the
and in the middle of the
(this is what is throwing the two casts and the related exception). Interestingly, if I do:
(meaning forcing a cast at the beginning of the
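A hypothetical reconstruction of that workaround (the module, buffer name, and gather are illustrative stand-ins, not the actual Hummingbird code): force the buffer back to long at the top of `forward()` so the later indexing sees the dtype it expects.

```python
import torch

class TreeSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffer whose dtype gets lost during translation (illustrative name)
        self.register_buffer("nodes", torch.tensor([2, 0, 1], dtype=torch.int64))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nodes = self.nodes.long()  # the explicit cast that works around the bug
        return torch.index_select(x, 0, nodes)

m = torch.jit.script(TreeSketch())
print(m(torch.tensor([10.0, 20.0, 30.0])))  # selects rows 2, 0, 1
```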
Could
Can you check the signature of the translated function and see if
Sure! Any suggestion on how to do that? (I am still getting familiar with TVM, to some extent :) )
Yes, see my updated comment.
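One way to check the signature of the translated function, sketched with a stand-in model: run the scripted module through `relay.frontend.from_pytorch` and print `mod["main"]`, whose signature shows the dtype inferred for every input and parameter. The TVM step is guarded so the sketch still runs without TVM installed; the input name is an assumption.

```python
import torch

# Stand-in model; the real check would use the scripted Hummingbird model
model = torch.jit.trace(torch.nn.Linear(4, 2), torch.randn(1, 4))

try:
    from tvm import relay
    # input_infos pairs input names with shapes; "input0" is an assumed name
    mod, params = relay.frontend.from_pytorch(model, [("input0", (1, 4))])
    print(mod["main"])  # the printed signature shows the inferred dtypes
except ImportError:
    print("TVM not installed; inspection step shown for reference")
```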
So it looks like it gets it right.
Yeah, the error also says
This is the full log in case:
So even though the tensor is probably typed, if I remove the initial cast to long the program fails.
This is the log if I remove the explicit cast.
Can you come up with a standalone repro script? I couldn't create a simple test that gives the same error.
I can try. I have pushed the code in this Hummingbird commit. The test file to run is
Thanks, I'll try Hummingbird. Compiling decision trees to TVM via PyTorch sounds very interesting!
Thanks! Before, we were manually translating the PyTorch implementation into Relay, but long term we want to use your PyTorch frontend. Any suggestions, help, or contributions are welcome! In our experiments with the Relay implementation of the models we were getting about a 3x performance improvement over PyTorch, and we hope to see similar gains with the frontend!
Ok, found the problem. The node_offset is indeed giving a warning at

Since

This is the same problem I mentioned earlier about the conv weight, bias, etc. But since they are float32 anyway, the default type we picked for them doesn't do any harm. Your use case is the first one we've met where one of the parameters to your model is indeed an integer tensor. Even though the dtype of

I'll look into it a bit more to see what's going on with the typing of parameter tensors at the Torchscript level. We might end up solving the "Untyped warning" problem for parameters.
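For reference, the float32 default mentioned above is surfaced by the frontend itself: recent TVM versions expose a `default_dtype` argument on `from_pytorch` that is applied to untyped parameter tensors (a sketch with a stand-in model; it would not by itself fix a mixed float/integer parameter set, since every untyped parameter gets the same dtype, and the argument name should be checked against your TVM version).

```python
import torch

# Stand-in model with float parameters only
model = torch.jit.trace(torch.nn.Linear(4, 2), torch.randn(1, 4))

try:
    from tvm import relay
    # default_dtype is what untyped parameter tensors fall back to;
    # float32 is the documented default in recent TVM releases
    mod, params = relay.frontend.from_pytorch(
        model, [("input0", (1, 4))], default_dtype="float32"
    )
    print(mod["main"])
except ImportError:
    print("TVM not installed; shown for reference only")
```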
By comparing the two Torchscript IRs, with or without

With the explicit cast, the rhs input of

Without the explicit cast, the rhs input of
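The comparison above can be sketched like this: script the same computation with and without the explicit cast and diff the printed graphs; the dtype annotated on the rhs input of `aten::add` is what differs (stand-in functions, not the actual model).

```python
import torch

# With the explicit cast: the rhs of the add is forced to long first
@torch.jit.script
def with_cast(x: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    return x + offset.long()

# Without the explicit cast: the rhs keeps whatever dtype it was given
@torch.jit.script
def without_cast(x: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    return x + offset

print(with_cast.graph)     # diff these two printouts and compare the
print(without_cast.graph)  # type on the rhs input of aten::add
```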
Interesting. Do you suggest opening an issue against PyTorch? It looks to me like there is no clear solution (besides adding the explicit cast) because this is more a design problem with Torchscript. Although, shouldn't tracing help? If we run records through the ops, it should be able to detect that the tensor is indeed of type long.
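On the tracing point: because tracing runs concrete tensors through the ops, the recorded graph carries the actual dtype rather than an untyped Tensor, as this small sketch shows (illustrative function, not the Hummingbird code).

```python
import torch

def gather(x, idx):
    # idx must be an integer tensor for index_select
    return torch.index_select(x, 0, idx)

# Tracing records the concrete dtypes of the example inputs, so the
# graph annotates idx as a Long tensor rather than an unknown Tensor.
traced = torch.jit.trace(
    gather, (torch.randn(5), torch.tensor([0, 2], dtype=torch.int64))
)
print(traced.graph)
```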
I don't know if Torch people would consider this a problem. Anyway, I've just opened #6311 to work around the dtype of GetAttr. With that change, the explicit cast is no longer necessary. The test case I added there should demonstrate the problem you are having.
Fantastic! This is great, thanks. I will test this and let you know, but from your test case it looks like it solves our problem. Thanks!
It worked, no more casts are needed and the warnings disappeared. |
#6311 is merged |
I have the following PyTorch program:
when I compile it into TVM, I get the following interesting trace:
Apparently, on line 14 the cast to long is translated into a cast to float32.
To reproduce, you can pull this branch, run
```
pip install -e .[extra]
```
and run this test file. I can try to generate a minimal running example if it helps.
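A hypothetical stand-in for the class of program involved (not the actual model; the names are illustrative and the TVM compile step is guarded so the sketch runs without TVM installed):

```python
import torch

class CastRepro(torch.nn.Module):
    def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # this is the kind of cast the frontend was turning into float32
        return torch.index_select(x, 0, idx.long())

scripted = torch.jit.script(CastRepro())
out = scripted(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([2.0, 0.0]))
print(out)  # indices [2, 0] select the third and first elements

# The TVM compile step, guarded so the sketch runs without TVM installed
try:
    from tvm import relay
    mod, _ = relay.frontend.from_pytorch(scripted, [("x", (3,)), ("idx", (2,))])
    print(mod["main"])  # inspect the cast ops in the translated function
except ImportError:
    pass
```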