-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add GPU RNNT Loss #1483
Add GPU RNNT Loss #1483
Conversation
This pull request was exported from Phabricator. Differential Revision: D28128853 |
0b8f57a
to
7f052b6
Compare
Summary: In pytorch#1479, we added support for CPU RNNT loss. This PR adds a GPU version of RNNT loss Differential Revision: D28128853 fbshipit-source-id: 3c610c7e9c3dda3fb309586d5dc71397752cd2e0
7f052b6
to
d8cfef6
Compare
d8cfef6
to
f6e783d
Compare
dbe81db
to
22f77ae
Compare
22f77ae
to
ed4f5b6
Compare
@malfet @seemethere Could I get your input on supporting custom CUDA kernel in torchaudio? One of our features (RNN transducer loss) for our the upcoming release involves a custom CUDA kernel, but this isn't something that torchaudio currently supports. I have a minimal working CUDA build in this PR that compiles successfully on AWS and in the CircleCI builds (mac builds are failing on master for an unrelated reason), but I’m not very familiar with CUDA builds. Are there any additional flags that we should include (torchvision includes these), or anything platform specific (ex/ for Windows) that we should add? |
30bf693
to
a31258f
Compare
@@ -0,0 +1,73 @@ | |||
#include <THC/THC.h> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you really need to include THC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can replace THC with something like <c10/cuda/CUDAStream.h>
to get the same result -- is there a standard for supporting the getCurrentCUDAStream
operator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Including <c10/cuda/CUDAStream.h> is fine. THC is legacy and on its way out.
a31258f
to
25c286e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, LGTM!
from .utils import skipIfNoTransducer | ||
|
||
|
||
@skipIfNoTransducer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: just realized the name has not been standardized from Transducer to RNNT, but this can be changed after this PR
Summary: In #1479, we added support for CPU RNNT loss. This PR adds a GPU version of RNNT loss
Differential Revision: D28128853