Migrate transducer input checks to C++ #1391
Conversation
Force-pushed from 2d15d0f to ce61ffe
Thanks for working on this! I've noted a few small changes (and some things we'll think about in separate PRs).
torchaudio/csrc/transducer.cpp
Outdated
TORCH_CHECK(
    input_lengths.size(0) == acts.size(0),
    "each output sequence must have a length");
TORCH_CHECK(
    label_lengths.size(0) == acts.size(0),
    "each example must have a label length");
Let's be explicit and consistent with python naming:
"batch dimension mismatch between acts and act_lens: each example must have a length"
"batch dimension mismatch between acts and label_lens: each example must have a label length"
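The suggested checks can be sketched as follows. This is a hypothetical Python-side equivalent of the C++ `TORCH_CHECK` calls (the function name `validate_batch_dims` and the use of plain Python sequences for shapes are illustration choices, not torchaudio's API), using the explicit messages proposed above:

```python
def validate_batch_dims(acts_shape, act_lens, label_lens):
    """Hypothetical sketch: batch-dimension checks with the reviewer's
    suggested explicit messages. Shapes and length lists are plain
    Python sequences so the example has no torch dependency."""
    batch = acts_shape[0]  # acts is (batch, time, label, class)
    if len(act_lens) != batch:
        raise ValueError(
            "batch dimension mismatch between acts and act_lens: "
            "each example must have a length")
    if len(label_lens) != batch:
        raise ValueError(
            "batch dimension mismatch between acts and label_lens: "
            "each example must have a label length")

# A consistent batch of 2 examples passes silently.
validate_batch_dims((2, 5, 4, 3), [5, 3], [3, 2])
```

Naming the offending pair of tensors in the message (rather than only describing the invariant) makes the error actionable without reading the C++ source.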
(As follow-up to this PR: we'll change the names of the variables in the C++ function to match those of the python function:
"""
Args:
acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
before applying ``torch.nn.functional.log_softmax``.
labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
"""
Also, I just realized that the last two descriptions are identical in the Python documentation :)
"""
Args:
acts (Tensor): Tensor of dimension (batch, time, label, class) containing output *sequence* from network
before applying ``torch.nn.functional.log_softmax``.
labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
label_lens (Tensor): Tensor of dimension (batch) containing the length of each *label*
"""
added * for suggested change.)
int maxT = acts.size(1);
int maxU = acts.size(2);
int minibatch_size = acts.size(0);
int alphabet_size = acts.size(3);

TORCH_CHECK(
    at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
(Follow-up beyond this PR: let's improve the readability here: "The maximum length of a sequence in acts must be equal to the maximal value given in act_lens" ?)
at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
TORCH_CHECK(
    at::max(label_lengths).item().toInt() + 1 == maxU,
    "output length mismatch");
(Follow-up beyond this PR: we'll want to improve this message too)
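The two length-consistency checks discussed above can be sketched as follows. This is a hypothetical, dependency-free Python rendering (the function name `validate_lengths` and the exact wording of the second message are assumptions; only the first message follows the phrasing suggested in the review), not torchaudio's actual implementation:

```python
def validate_lengths(acts_shape, act_lens, label_lens):
    """Hypothetical sketch: length-consistency checks with more
    readable messages. acts_shape is (batch, time, label, class)."""
    max_t = acts_shape[1]  # maxT: padded time dimension of acts
    max_u = acts_shape[2]  # maxU: padded label dimension of acts
    if max(act_lens) != max_t:
        raise ValueError(
            "The maximum length of a sequence in acts must be equal "
            "to the maximal value given in act_lens")
    # The label dimension includes the blank-prepended position,
    # hence the "+ 1" relative to the longest label sequence.
    if max(label_lens) + 1 != max_u:
        raise ValueError(
            "The label dimension of acts must be one more than the "
            "maximal value given in label_lens")

# Consistent shapes pass silently: maxT == max(act_lens) == 5,
# maxU == max(label_lens) + 1 == 4.
validate_lengths((2, 5, 4, 3), [5, 3], [3, 2])
```

Spelling out which tensor dimension is compared against which length tensor removes the guesswork that a bare "input length mismatch" leaves to the user.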
Force-pushed from dbaf97c to 21b5abf
Force-pushed from 21b5abf to 9b72d80
Move input checks for RNNT from Python to C++
RNN Transducer Loss Issue: #1240