Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate transducer input checks to C++ #1391

Merged

Conversation

carolineechen
Copy link
Contributor

Move input checks for RNNT from Python to C++

RNN Transducer Loss Issue: #1240

@carolineechen carolineechen force-pushed the migrate_transducer_input_checks branch from 2d15d0f to ce61ffe Compare March 15, 2021 15:16
torchaudio/csrc/transducer.cpp Outdated Show resolved Hide resolved
torchaudio/csrc/transducer.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! I've noted a few small changes (and some things we'll think about in separate PRs)

torchaudio/csrc/transducer.cpp Outdated Show resolved Hide resolved
Comment on lines 31 to 36
TORCH_CHECK(
input_lengths.size(0) == acts.size(0),
"each output sequence must have a length");
TORCH_CHECK(
label_lengths.size(0) == acts.size(0),
"each example must have a label length");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's be explicit and consistent with python naming:

      "batch dimension mismatch between acts and act_lens: each example must have a length"
      "batch dimension mismatch between acts and label_lens: each example must have a label length"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(As follow-up to this PR: we'll change the names of the variables in the C++ function to match those of the python function:

"""
        Args:
            acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
                before applying ``torch.nn.functional.log_softmax``.
            labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
            act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
            label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
        """

and, also, I just realized the two last descriptions are the same in the python documentation :)

"""
        Args:
            acts (Tensor): Tensor of dimension (batch, time, label, class) containing output *sequence* from network
                before applying ``torch.nn.functional.log_softmax``.
            labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
            act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
            label_lens (Tensor): Tensor of dimension (batch) containing the length of each *label*
        """

added * for suggested change.)

torchaudio/csrc/transducer.cpp Outdated Show resolved Hide resolved
torchaudio/csrc/transducer.cpp Outdated Show resolved Hide resolved
torchaudio/csrc/transducer.cpp Show resolved Hide resolved
int maxT = acts.size(1);
int maxU = acts.size(2);
int minibatch_size = acts.size(0);
int alphabet_size = acts.size(3);

TORCH_CHECK(
at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Follow-up beyond this PR: let's improve the readability here: "The maximum length of a sequence in acts must be equal to the maximal value given in act_lens" ?)

at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
TORCH_CHECK(
at::max(label_lengths).item().toInt() + 1 == maxU,
"output length mismatch");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Follow-up beyond this PR: we'll want to improve this message too)

@carolineechen carolineechen force-pushed the migrate_transducer_input_checks branch from dbaf97c to 21b5abf Compare March 15, 2021 23:12
@carolineechen carolineechen force-pushed the migrate_transducer_input_checks branch from 21b5abf to 9b72d80 Compare March 15, 2021 23:14
@carolineechen carolineechen merged commit f06074a into pytorch:master Mar 16, 2021
@carolineechen carolineechen deleted the migrate_transducer_input_checks branch March 16, 2021 13:45
mthrok pushed a commit to mthrok/audio that referenced this pull request Dec 13, 2022
* Update index.rst

* Update layout.html
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants