
Add RNN Transducer Loss for CPU #1137

Merged (4 commits into pytorch:master, Jan 5, 2021)

Conversation

@vincentqb (Contributor) commented Dec 30, 2020

This pull request introduces rnnt_loss and RNNTLoss as a prototype in torchaudio.prototype.transducer using HawkAaron's warp-transducer.

  • The Python interface currently remains the same as the original.
  • This has been tested as integrated within ESPnet here.

Follow-up work detailed in #1240.

cc @astaff, internal, #1099

@vincentqb vincentqb force-pushed the transducer-cpu branch 2 times, most recently from 9d6589a to e2e6562 Compare December 30, 2020 18:57
@cpuhrsch cpuhrsch changed the title Add RNN Transducer Loss Add RNN Transducer Loss for CPU Dec 30, 2020
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
"""RNN Transducer Loss

Args:
Contributor:

I think the documentation could be improved a bit. It could also be useful to reference the paper.
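For example, the docstring could spell out the tensor shapes and cite Graves (2012). The argument descriptions below are a sketch inferred from the warp-transducer README, not final wording:

```python
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """RNN Transducer loss (Graves, "Sequence Transduction with Recurrent
    Neural Networks", 2012, https://arxiv.org/abs/1211.3711).

    Args:
        acts: joint network output, shape (batch, max_seq_len, max_label_len + 1, num_classes)
        labels: zero-padded label sequences, shape (batch, max_label_len)
        act_lens: valid lengths of each sequence in the batch, shape (batch,)
        label_lens: valid lengths of each label sequence, shape (batch,)
        blank: index of the blank label (default 0)
        reduction: "none", "mean", or "sum" (default "mean")

    Returns:
        The loss, reduced over the batch according to ``reduction``.
    """
```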

super().build_extension(ext)


_TRANSDUCER_NAME = '_warp_transducer'
Collaborator:

This will get installed into the global namespace, outside the torchaudio package directory.
Please put it inside the torchaudio package.

MESSAGE(STATUS "Building static library with GPU support")

CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu)
IF (!Torch_FOUND)
Collaborator:

If torch is not found, shouldn't it be failing?
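One way to make the configure step fail hard when Torch is missing, sketched with CMake's FATAL_ERROR mode (note that CMake spells negation NOT rather than !):

```cmake
IF (NOT Torch_FOUND)
    MESSAGE(FATAL_ERROR "Torch is required to build the transducer extension, but it was not found")
ENDIF()
```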

(Outdated, resolved review comment on torchaudio/__init__.py)
self.reduction = reduction
self.loss = _RNNT.apply

def forward(self, acts, labels, act_lens, label_lens):
Contributor:

If you don't want to copy-paste the docs from the functional you could reference it here within the documentation.
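For instance (a sketch; the exact Sphinx cross-reference target is an assumption):

```python
class RNNTLoss:  # subclasses torch.nn.Module in the actual code; omitted here
    """RNN Transducer Loss.

    See :func:`torchaudio.prototype.transducer.rnnt_loss` for the argument
    descriptions and the reference to the paper.
    """
```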

@vincentqb (Contributor, Author) commented Dec 31, 2020

Some TODOs:

  • Use TORCH_CHECK to raise an error instead of writing to standard output, as done here.
  • Investigate the memory allocation for the workspace, here.

Some follow-ups:

  • Move libsox to a third_party subfolder as suggested above.
  • Investigate using AT_DISPATCH_FLOATING_TYPES.
  • Migrate the checks to C++.
  • Add GPU implementation and compilation.
  • Patch the submodule to remove the pytorch deprecation warnings.
  • Refactor, see internal.
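
On the TORCH_CHECK item: the idea is to throw instead of printing and continuing. Since this sketch doesn't link against PyTorch, a hand-rolled macro stands in for the real TORCH_CHECK (which lives in c10/util/Exception.h and throws c10::Error):

```cpp
#include <sstream>
#include <stdexcept>

// Stand-in for TORCH_CHECK: evaluate a condition and throw on failure
// instead of writing a message to standard output.
#define CHECK_ARG(cond, msg)                 \
  do {                                       \
    if (!(cond)) {                           \
      std::ostringstream ss;                 \
      ss << "rnnt_loss: " << msg;            \
      throw std::runtime_error(ss.str());    \
    }                                        \
  } while (0)

// Hypothetical example: validate the blank index before calling into
// warp-transducer.
void check_blank(int blank, int num_classes) {
  CHECK_ARG(blank >= 0 && blank < num_classes,
            "blank must be within [0, num_classes), got " << blank);
}
```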

@vincentqb vincentqb force-pushed the transducer-cpu branch 4 times, most recently from 82b7186 to 456eefc Compare January 3, 2021 22:51
# Test if example provided in README runs
# https://github.com/HawkAaron/warp-transducer

acts = torch.FloatTensor(
Contributor:

nit: use the factory function torch.tensor([xyz], dtype=torch.float) instead of the type constructor. Same applies to IntTensor.
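The suggested change, illustrated with arbitrary values:

```python
import torch

# Type constructors (discouraged):
acts = torch.FloatTensor([[0.1, 0.6], [0.1, 0.1]])
act_lens = torch.IntTensor([2])

# Factory function with an explicit dtype (preferred):
acts = torch.tensor([[0.1, 0.6], [0.1, 0.1]], dtype=torch.float)
act_lens = torch.tensor([2], dtype=torch.int)
```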

U = data["tgt_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
Contributor:

self.assertEqual should be preferred

@vincentqb vincentqb force-pushed the transducer-cpu branch 2 times, most recently from f18105a to 1d2c5db Compare January 5, 2021 05:33
@vincentqb (Contributor, Author) commented Jan 5, 2021

Some more TODOs:

  • Remove numpy from tests
  • Guard prototype csrc and extension

Some more follow-ups:

  • Pass along the DEBUG flag to cmake
  • Guard prototype python files by omitting them from torchaudio, see also comment
  • Guard building third party transducer even if not added as an extension
  • Enable building transducer in nightlies only, not release.
  • Build within same folders as libsox, comment and comment
  • Remove hardcoded O2/O3 optimization, see comment

Error below also happens on master:

conda_build.exceptions.DependencyNeedsBuildingError: Unsatisfiable dependencies for platform osx-64: {"python[version='>=2.7,<2.8.0a0|>=3.8,<3.9.0a0|>=3.5,<3.6.0a0|>=3.5']", "python[version='>=3.6,<3.7.0a0|>=3.7,<3.8.0a0|>=3.9,<3.10.0a0']", "python_abi[version='3.6.*|3.7.*',build='*_cp36m|*_cp37m']", "python[version='>=2.7,<2.8.0a0|>=3.6,<3.7.0a0|>=3.7,<3.8.0a0|>=3.9,<3.10.0a0|>=3.8,<3.9.0a0|>=3.5,<3.6.0a0']", 'python', 'python_abi=3.9[build=*_cp39]'}

cc comment above

@vincentqb vincentqb force-pushed the transducer-cpu branch 2 times, most recently from 64c8220 to 32e3398 Compare January 5, 2021 16:23
loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()

def _test_costs_and_gradients(
Contributor:

This could be inlined since it has only one call site and is pretty small (though that alone isn't necessarily a reason to remove an abstraction).

@vincentqb vincentqb marked this pull request as ready for review January 5, 2021 19:05
@vincentqb vincentqb merged commit 6b07bcf into pytorch:master Jan 5, 2021
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &cpu_rnnt_loss);
Collaborator:

@vincentqb Can you define a proper namespace? torchaudio::<something>::rnnt_loss

I am not sure how you want to move on, but if you plan to add different types of RNN-T later, a more descriptive name like warprnnt would work better.

Collaborator:

Adding anonymous namespace in #1159 for the time being.
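For context, an anonymous namespace gives its contents internal linkage, so the symbols are invisible outside the translation unit; a minimal sketch (not the actual #1159 code):

```cpp
namespace {

// Internal linkage: this symbol cannot clash with a cpu_rnnt_loss
// defined in any other translation unit.
double cpu_rnnt_loss_stub(double x) { return 2.0 * x; }

}  // namespace

// Code later in the same file (e.g. the TORCH_LIBRARY_IMPL block)
// can still reference the function.
double call_stub(double x) { return cpu_rnnt_loss_stub(x); }
```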

@mthrok (Collaborator) commented Jan 7, 2021

@vincentqb I updated the follow-up description for the items addressed in #1159 and #1161. Please stamp these PRs when you have time.

For "Enable building transducer in nightlies only, disable for release": I am thinking of adding a master environment variable that has higher precedence than BUILD_TRANSDUCER, like DISABLE_PROTOTYPE, and propagating it from the CCI configuration.
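That precedence could look roughly like this in the build tooling (a sketch; the variable names come from the comment above, and the function itself is hypothetical):

```python
def transducer_enabled(env):
    """BUILD_TRANSDUCER opts in, but the master DISABLE_PROTOTYPE
    switch takes precedence and forces the extension off."""
    if env.get("DISABLE_PROTOTYPE", "0") == "1":
        return False
    return env.get("BUILD_TRANSDUCER", "0") == "1"
```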

This was referenced Jan 8, 2021
@mthrok (Collaborator) commented Jan 9, 2021

For the C++ ABI issue, see #880.

@vincentqb vincentqb mentioned this pull request Feb 4, 2021
22 tasks
mthrok pushed a commit to mthrok/audio that referenced this pull request Feb 26, 2021