🐛 [Bug] Fallback for torch.nn.functional.one_hot fails #814
@chaoz-dev I tried this repro without any Torch-TensorRT code and I still get the same error. I do have a fix for the defaulting-to-FP32 issue, however.
inferred type. fixes: #814 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
With regard to the fix, can you elaborate on "user settings"? Will these come from the input type annotations during graph compilation?
I just tried running the model in plain PyTorch (with the compile step commented out) and got the same error. The fix addresses an issue where the type map for partitioning wasn't populated properly when we couldn't infer the type.
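A minimal sketch of the kind of fallback described above. It assumes that "user settings" means the dtypes the user declared for the inputs, which the thread does not confirm. The type map, the function name, and the dtype strings are hypothetical and are not Torch-TensorRT's internals. Inferred types take priority, then user-declared types, and FP32 is used only as a last resort.

```python
from typing import Dict, Optional

# Hypothetical stand-in for a partitioning type map: input name -> dtype string.
# None marks an input whose dtype could not be inferred from the graph.
TypeMap = Dict[str, Optional[str]]


def resolve_input_dtypes(type_map: TypeMap,
                         user_dtypes: Dict[str, str],
                         default: str = "float32") -> Dict[str, str]:
    """Return a fully populated type map.

    Inferred dtypes win. When inference failed, use the dtype the user
    declared for that input, and only fall back to `default` (FP32) when
    the user declared nothing.
    """
    resolved: Dict[str, str] = {}
    for name, inferred in type_map.items():
        if inferred is not None:
            resolved[name] = inferred
        else:
            resolved[name] = user_dtypes.get(name, default)
    return resolved


# An int64 index input whose type could not be inferred keeps the user's dtype
# instead of silently becoming float32.
print(resolve_input_dtypes({"x": None}, {"x": "int64"}))
```

The design point is ordering: defaulting straight to FP32 when inference fails is what sends floating-point values into an integer-only op such as one_hot.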
Ah gotcha gotcha 👍🏼
Updating the above script so that running

I'm still seeing the same error, however:
Just tested this using NGC. It seems like we're still inferring FP32 here.
It's also possible that the issue is that we're trying to compile a single op that sits at both the beginning and the end of the graph, so we don't fall back as we should and leave the tensor alone (the tensor is int64, which cannot be converted in TRT).
Should I reopen this ticket or create a new one?
I would say create a new one. Also, have you tried setting the dtype of the input to int32?
Also, if it's a single unsupported op in the graph, the expected behavior is to return the original module with no changes.
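The expected behavior stated above can be sketched with a simplified partitioner. `Segment`, `partition`, and `should_return_original` are hypothetical names used for illustration only. They are not Torch-TensorRT's partitioning API. The sketch groups consecutive ops by whether a converter exists. If no segment would run in TensorRT, compiling changes nothing, so the original module should be returned.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Segment:
    """Hypothetical graph segment: a run of consecutive ops on one backend."""
    ops: List[str]
    target: str  # "tensorrt" or "torch"


def partition(ops: List[str], supported: Set[str]) -> List[Segment]:
    """Group consecutive ops by whether a TensorRT converter exists for them."""
    segments: List[Segment] = []
    for op in ops:
        target = "tensorrt" if op in supported else "torch"
        if segments and segments[-1].target == target:
            segments[-1].ops.append(op)  # extend the current run
        else:
            segments.append(Segment([op], target))  # start a new run
    return segments


def should_return_original(ops: List[str], supported: Set[str]) -> bool:
    """True when no segment would run in TensorRT, so there is nothing to compile."""
    return all(seg.target == "torch" for seg in partition(ops, supported))


# A graph whose only op is unsupported: leave the module (and its int64 input) alone.
print(should_return_original(["aten::one_hot"], {"aten::relu"}))
```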
Bug Description
Fallback for torch.nn.functional.one_hot, whether automatic or forced, appears to fail with the following message:

It appears that we attempt to pass floating-point values to one_hot during compilation, which fails because one_hot only accepts integer types.

To Reproduce
Run the following:
Expected behavior
Expect the above to compile without issues.
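The failure comes down to one_hot's input contract. A one-hot encoding uses each value as a class index, so a floating-point value has no meaningful row position. PyTorch's torch.nn.functional.one_hot expects a LongTensor of class indices. The pure-Python sketch below is not PyTorch's implementation. It only mirrors that constraint so the error is easy to see.

```python
from typing import List, Sequence


def one_hot(indices: Sequence[int], num_classes: int) -> List[List[int]]:
    """Pure-Python one-hot encoding that only accepts integer class indices."""
    rows: List[List[int]] = []
    for idx in indices:
        # bool is a subclass of int in Python; reject it explicitly, along with floats.
        if not isinstance(idx, int) or isinstance(idx, bool):
            raise TypeError(f"one_hot expects integer indices, got {type(idx).__name__}")
        if not 0 <= idx < num_classes:
            raise ValueError(f"index {idx} out of range for {num_classes} classes")
        row = [0] * num_classes
        row[idx] = 1  # the index selects which position is set
        rows.append(row)
    return rows


print(one_hot([0, 2], 3))
```

Passing `[1.0]` instead of `[1]` raises `TypeError`. This is the same class of failure seen when the compiler feeds FP32 values into the fallback one_hot.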
Environment
PyTorch installation method (conda, pip, libtorch, source): conda
Build command (if compiling from source): python setup.py install
Additional context