Add RNN Transducer Loss for CPU #1137
Conversation
Force-pushed 9d6589a to e2e6562
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """RNN Transducer Loss

    Args:
I think the documentation could be improved a bit. It could also be useful to reference the paper.
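A hedged sketch of what a fuller docstring might look like (the argument shapes are assumptions based on warp-transducer's convention, not taken from this PR); the relevant paper is Graves, *Sequence Transduction with Recurrent Neural Networks* (2012):

```python
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """Compute the RNN Transducer loss from *Sequence Transduction with
    Recurrent Neural Networks* [Graves, 2012].

    Args:
        acts: joint-network output (shape assumed ``(batch, max_T, max_U, num_classes)``)
        labels: padded target label sequences for each batch element
        act_lens: valid length of each sequence in ``acts``
        label_lens: valid length of each sequence in ``labels``
        blank (int): index of the blank label (default: 0)
        reduction (str): ``'none'`` | ``'mean'`` | ``'sum'`` (default: ``'mean'``)
    """
    raise NotImplementedError  # docstring sketch only; the PR has the real body
```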
        super().build_extension(ext)

_TRANSDUCER_NAME = '_warp_transducer'
This will get installed in the global namespace, outside of the torchaudio package directory. Please put it in the torchaudio package.
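One way to express this in the build script — a hypothetical setup.py sketch (names assumed): a setuptools extension's dotted name controls where the compiled module is installed.

```python
from setuptools import Extension

# Top-level name: installed directly into site-packages, polluting the
# global namespace (what the comment above warns against).
global_ext = Extension("_warp_transducer", sources=[])

# Dotted name: installed inside the torchaudio package directory.
packaged_ext = Extension("torchaudio._warp_transducer", sources=[])

assert packaged_ext.name.startswith("torchaudio.")
```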
MESSAGE(STATUS "Building static library with GPU support")

CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu)
IF (!Torch_FOUND)
If torch is not found, shouldn't it be failing?
        self.reduction = reduction
        self.loss = _RNNT.apply

    def forward(self, acts, labels, act_lens, label_lens):
If you don't want to copy-paste the docs from the functional you could reference it here within the documentation.
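A minimal sketch (plain class, hypothetical body) of how the module's docs could defer to the functional's documentation instead of duplicating it:

```python
class RNNTLoss:
    """RNN Transducer loss module.

    See :py:func:`rnnt_loss` for a description of the arguments.
    """

    def forward(self, acts, labels, act_lens, label_lens):
        """See :py:func:`rnnt_loss`."""
        raise NotImplementedError  # sketch only; the real module calls the functional
```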
Force-pushed b6c4ce8 to ca66151
Some TODOs:
Some follow-ups:
Force-pushed 82b7186 to 456eefc
# Test if example provided in README runs
# https://github.com/HawkAaron/warp-transducer

acts = torch.FloatTensor(
nit: use the factory function torch.tensor([xyz], dtype=torch.float)
instead of the type constructor. Same applies to IntTensor.
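For illustration, the two styles side by side (values made up): torch.tensor takes an explicit dtype argument, while the legacy type constructors hard-code it in the class name.

```python
import torch

# Type constructor (discouraged): dtype is baked into the class name.
acts_legacy = torch.FloatTensor([[0.1, 0.6, 0.3]])

# Factory function (preferred): dtype is an explicit, greppable argument.
acts = torch.tensor([[0.1, 0.6, 0.3]], dtype=torch.float)
labels = torch.tensor([[1, 2]], dtype=torch.int)  # replaces torch.IntTensor

assert acts.dtype == torch.float32
assert labels.dtype == torch.int32
```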
Force-pushed f96089b to 299310c
U = data["tgt_lengths"][b]
for t in range(gradients.shape[1]):
    for u in range(gradients.shape[2]):
        np.testing.assert_allclose(
self.assertEqual should be preferred
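A stdlib-only sketch of the suggested idiom (reference values are made up); in PyTorch's TestCase subclass, self.assertEqual also compares tensors with tolerance handling, so one idiom covers scalars, lists, and tensors:

```python
import unittest

class RNNTLossTest(unittest.TestCase):  # hypothetical test case
    def test_costs_match_reference(self):
        ref_costs = [4.28, 3.94]  # made-up reference values
        costs = [4.28, 3.94]
        # Failures surface through the unittest machinery instead of a raw
        # numpy AssertionError, and the idiom is uniform across types.
        self.assertEqual(costs, ref_costs)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(RNNTLossTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```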
Force-pushed f18105a to 1d2c5db
Some more TODOs:
Some more follow-ups:
Error below also happens on master:
Force-pushed 64c8220 to 32e3398
loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()

def _test_costs_and_gradients(
This could be inlined since it only has one call-site and is pretty small (but that's not the reason to remove an abstraction necessarily).
Force-pushed 32e3398 to fddfbd1
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &cpu_rnnt_loss);
@vincentqb Can you define a proper namespace? torchaudio::<something>::rnnt_loss
I am not sure how you want to move on, but if you plan to add different types of RNNT later, a more descriptive name would work better, like warprnnt.
Adding anonymous namespace in #1159 for the time being.
@vincentqb I updated the follow-up description for things addressed in #1159 and #1161. Please stamp these PRs when you have time.
For the C++ ABI issue, see #880.
This pull request introduces rnnt_loss and RNNTLoss as a prototype in torchaudio.prototype.transducer, using HawkAaron's warp-transducer. Follow-up work detailed in #1240.
cc @astaff, internal, #1099