Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU RNNT Loss #1483

Merged

Conversation

carolineechen
Copy link
Contributor

Summary: In #1479, we added support for CPU RNNT loss. This PR adds a GPU version of RNNT loss

Differential Revision: D28128853

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D28128853

@carolineechen carolineechen changed the base branch from fbsync to master April 30, 2021 21:16
@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from 0b8f57a to 7f052b6 Compare April 30, 2021 21:23
Summary: In pytorch#1479, we added support for CPU RNNT loss. This PR adds a GPU version of RNNT loss

Differential Revision: D28128853

fbshipit-source-id: 3c610c7e9c3dda3fb309586d5dc71397752cd2e0
@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from 7f052b6 to d8cfef6 Compare April 30, 2021 22:03
@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from d8cfef6 to f6e783d Compare April 30, 2021 22:08
@carolineechen carolineechen marked this pull request as draft April 30, 2021 23:09
@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch 2 times, most recently from dbe81db to 22f77ae Compare May 4, 2021 16:33
@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from 22f77ae to ed4f5b6 Compare May 4, 2021 16:57
@carolineechen
Copy link
Contributor Author

@malfet @seemethere Could I get your input on supporting custom CUDA kernel in torchaudio?

One of our features (RNN transducer loss) for our the upcoming release involves a custom CUDA kernel, but this isn't something that torchaudio currently supports. I have a minimal working CUDA build in this PR that compiles successfully on AWS and in the CircleCI builds (mac builds are failing on master for an unrelated reason), but I’m not very familiar with CUDA builds. Are there any additional flags that we should include (torchvision includes these), or anything platform specific (ex/ for Windows) that we should add?

@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from 30bf693 to a31258f Compare May 5, 2021 18:37
build_tools/setup_helpers/extension.py Outdated Show resolved Hide resolved
@@ -0,0 +1,73 @@
#include <THC/THC.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need to include THC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can replace THC with something like <c10/cuda/CUDAStream.h> to get the same result -- is there a standard for supporting the getCurrentCUDAStream operator?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including <c10/cuda/CUDAStream.h> is fine. THC is legacy and on its way out.

@carolineechen carolineechen force-pushed the export-D28128853-to-fbsync branch from a31258f to 25c286e Compare May 5, 2021 20:14
@carolineechen carolineechen marked this pull request as ready for review May 5, 2021 20:14
@carolineechen carolineechen requested a review from vincentqb May 5, 2021 20:14
Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, LGTM!

from .utils import skipIfNoTransducer


@skipIfNoTransducer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: just realized the name has not been standardized from Transducer to RNNT, but this can be changed after this PR

@carolineechen carolineechen merged commit 5417e4f into pytorch:master May 6, 2021
@vincentqb vincentqb mentioned this pull request May 6, 2021
22 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants