diff --git a/.circleci/torchscript_bc_test/common.sh b/.circleci/torchscript_bc_test/common.sh
index 24ad5e45fb..6f64e7204c 100644
--- a/.circleci/torchscript_bc_test/common.sh
+++ b/.circleci/torchscript_bc_test/common.sh
@@ -66,5 +66,6 @@ build_master() {
     conda install -y -q pytorch "cpuonly" -c pytorch-nightly
     printf "* Installing torchaudio\n"
     cd "${_root_dir}" || exit 1
-    BUILD_SOX=1 python setup.py clean install
+    git submodule update --init --recursive
+    BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py clean install
 }
diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh
index 27dec79251..e7543c762a 100755
--- a/.circleci/unittest/linux/scripts/install.sh
+++ b/.circleci/unittest/linux/scripts/install.sh
@@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}
 
 # 2. Install torchaudio
 printf "* Installing torchaudio\n"
-BUILD_SOX=1 python setup.py install
+BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install
 
 # 3. Install Test tools
 printf "* Installing test tools\n"
diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.circleci/unittest/linux/scripts/setup_env.sh
index 26292a00d5..f56b211ac9 100755
--- a/.circleci/unittest/linux/scripts/setup_env.sh
+++ b/.circleci/unittest/linux/scripts/setup_env.sh
@@ -43,6 +43,7 @@ conda activate "${env_dir}"
 pip --quiet install cmake ninja
 
 # 4. Buld codecs
+git submodule update --init --recursive
 mkdir -p third_party/build
 (
     cd third_party/build
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000..c01f8c91ad
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,4 @@
+[submodule "third_party/transducer/submodule"]
+	path = third_party/transducer/submodule
+	url = https://github.com/HawkAaron/warp-transducer
+	ignore = dirty
diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py
index 8e6cd337cb..07ee05b520 100644
--- a/build_tools/setup_helpers/extension.py
+++ b/build_tools/setup_helpers/extension.py
@@ -20,20 +20,21 @@
 _TP_INSTALL_DIR = _TP_BASE_DIR / 'install'
 
 
-def _get_build_sox():
-    val = os.environ.get('BUILD_SOX', '0')
+def _get_build(var):
+    val = os.environ.get(var, '0')
     trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
     falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
     if val in trues:
         return True
     if val not in falses:
         print(
-            f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
+            f'WARNING: Unexpected environment variable value `{var}={val}`. '
             f'Expected one of {trues + falses}')
     return False
 
 
-_BUILD_SOX = _get_build_sox()
+_BUILD_SOX = _get_build("BUILD_SOX")
+_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
 
 
 def _get_eca(debug):
@@ -42,6 +43,8 @@ def _get_eca(debug):
         eca += ["-O0", "-g"]
     else:
         eca += ["-O3"]
+    if _BUILD_TRANSDUCER:
+        eca += ['-DBUILD_TRANSDUCER']
     return eca
 
 
@@ -67,6 +70,8 @@ def _get_include_dirs():
     ]
     if _BUILD_SOX:
         dirs.append(str(_TP_INSTALL_DIR / 'include'))
+    if _BUILD_TRANSDUCER:
+        dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
     return dirs
 
 
@@ -94,6 +99,8 @@ def _get_extra_objects():
     ]
     for lib in libs:
         objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
+    if _BUILD_TRANSDUCER:
+        objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
     return objs
 
 
diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh
index ad45a08d59..957b41e048 100755
--- a/packaging/build_wheel.sh
+++ b/packaging/build_wheel.sh
@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
     python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
     python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
 else
-    BUILD_SOX=1 python setup.py bdist_wheel
+    BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py bdist_wheel
 fi
diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash
index 3316f789e2..635bfa3a14 100644
--- a/packaging/pkg_helpers.bash
+++ b/packaging/pkg_helpers.bash
@@ -103,6 +103,7 @@ setup_macos() {
 #
 # Usage: setup_env 0.2.0
 setup_env() {
+  git submodule update --init --recursive
   setup_cuda
   setup_build_version "$1"
   setup_macos
diff --git a/packaging/torchaudio/build.sh b/packaging/torchaudio/build.sh
index 99c17b6913..88bbfce375 100644
--- a/packaging/torchaudio/build.sh
+++ b/packaging/torchaudio/build.sh
@@ -1,4 +1,4 @@
 #!/usr/bin/env bash
 set -ex
 
-BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
+BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
diff --git a/test/torchaudio_unittest/transducer_test.py b/test/torchaudio_unittest/transducer_test.py
new file mode 100644
index 0000000000..3d8dc0683d
--- /dev/null
+++ b/test/torchaudio_unittest/transducer_test.py
@@ -0,0 +1,276 @@
+import torch
+from torchaudio.prototype.transducer import RNNTLoss
+
+from torchaudio_unittest import common_utils
+
+
+def get_data_basic(device):
+    # Example provided
+    # in 6f73a2513dc784c59eec153a45f40bc528355b18
+    # of https://github.com/HawkAaron/warp-transducer
+
+    acts = torch.tensor(
+        [
+            [
+                [
+                    [0.1, 0.6, 0.1, 0.1, 0.1],
+                    [0.1, 0.1, 0.6, 0.1, 0.1],
+                    [0.1, 0.1, 0.2, 0.8, 0.1],
+                ],
+                [
+                    [0.1, 0.6, 0.1, 0.1, 0.1],
+                    [0.1, 0.1, 0.2, 0.1, 0.1],
+                    [0.7, 0.1, 0.2, 0.1, 0.1],
+                ],
+            ]
+        ],
+        dtype=torch.float,
+    )
+    labels = torch.tensor([[1, 2]], dtype=torch.int)
+    act_length = torch.tensor([2], dtype=torch.int)
+    label_length = torch.tensor([2], dtype=torch.int)
+
+    acts = acts.to(device)
+    labels = labels.to(device)
+    act_length = act_length.to(device)
+    label_length = label_length.to(device)
+
+    acts.requires_grad_(True)
+
+    return acts, labels, act_length, label_length
+
+
+def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
+    # Test from D21322854
+
+    logits = torch.tensor(
+        [
+            0.065357,
+            0.787530,
+            0.081592,
+            0.529716,
+            0.750675,
+            0.754135,
+            0.609764,
+            0.868140,
+            0.622532,
+            0.668522,
+            0.858039,
+            0.164539,
+            0.989780,
+            0.944298,
+            0.603168,
+            0.946783,
+            0.666203,
+            0.286882,
+            0.094184,
+            0.366674,
+            0.736168,
+            0.166680,
+            0.714154,
+            0.399400,
+            0.535982,
+            0.291821,
+            0.612642,
+            0.324241,
+            0.800764,
+            0.524106,
+            0.779195,
+            0.183314,
+            0.113745,
+            0.240222,
+            0.339470,
+            0.134160,
+            0.505562,
+            0.051597,
+            0.640290,
+            0.430733,
+            0.829473,
+            0.177467,
+            0.320700,
+            0.042883,
+            0.302803,
+            0.675178,
+            0.569537,
+            0.558474,
+            0.083132,
+            0.060165,
+            0.107958,
+            0.748615,
+            0.943918,
+            0.486356,
+            0.418199,
+            0.652408,
+            0.024243,
+            0.134582,
+            0.366342,
+            0.295830,
+            0.923670,
+            0.689929,
+            0.741898,
+            0.250005,
+            0.603430,
+            0.987289,
+            0.592606,
+            0.884672,
+            0.543450,
+            0.660770,
+            0.377128,
+            0.358021,
+        ],
+        dtype=dtype,
+    ).reshape(2, 4, 3, 3)
+
+    targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
+    src_lengths = torch.tensor([4, 4], dtype=torch.int32)
+    tgt_lengths = torch.tensor([2, 2], dtype=torch.int32)
+
+    blank = 0
+
+    ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
+
+    ref_gradients = torch.tensor(
+        [
+            -0.186844,
+            -0.062555,
+            0.249399,
+            -0.203377,
+            0.202399,
+            0.000977,
+            -0.141016,
+            0.079123,
+            0.061893,
+            -0.011552,
+            -0.081280,
+            0.092832,
+            -0.154257,
+            0.229433,
+            -0.075176,
+            -0.246593,
+            0.146405,
+            0.100188,
+            -0.012918,
+            -0.061593,
+            0.074512,
+            -0.055986,
+            0.219831,
+            -0.163845,
+            -0.497627,
+            0.209240,
+            0.288387,
+            0.013605,
+            -0.030220,
+            0.016615,
+            0.113925,
+            0.062781,
+            -0.176706,
+            -0.667078,
+            0.367659,
+            0.299419,
+            -0.356344,
+            -0.055347,
+            0.411691,
+            -0.096922,
+            0.029459,
+            0.067463,
+            -0.063518,
+            0.027654,
+            0.035863,
+            -0.154499,
+            -0.073942,
+            0.228441,
+            -0.166790,
+            -0.000088,
+            0.166878,
+            -0.172370,
+            0.105565,
+            0.066804,
+            0.023875,
+            -0.118256,
+            0.094381,
+            -0.104707,
+            -0.108934,
+            0.213642,
+            -0.369844,
+            0.180118,
+            0.189726,
+            0.025714,
+            -0.079462,
+            0.053748,
+            0.122328,
+            -0.238789,
+            0.116460,
+            -0.598687,
+            0.302203,
+            0.296484,
+        ],
+        dtype=dtype,
+    ).reshape(2, 4, 3, 3)
+
+    logits.requires_grad_(True)
+    logits = logits.to(device)
+
+    def grad_hook(grad):
+        logits.saved_grad = grad.clone()
+
+    logits.register_hook(grad_hook)
+
+    data = {
+        "logits": logits,
+        "targets": targets,
+        "src_lengths": src_lengths,
+        "tgt_lengths": tgt_lengths,
+        "blank": blank,
+    }
+
+    return data, ref_costs, ref_gradients
+
+
+def compute_with_pytorch_transducer(data):
+    costs = RNNTLoss(blank=data["blank"], reduction="none")(
+        acts=data["logits"],
+        labels=data["targets"],
+        act_lens=data["src_lengths"],
+        label_lens=data["tgt_lengths"],
+    )
+
+    loss = torch.sum(costs)
+    loss.backward()
+    costs = costs.cpu()
+    gradients = data["logits"].saved_grad.cpu()
+    return costs, gradients
+
+
+class TransducerTester:
+    def test_basic_fp16_error(self):
+        rnnt_loss = RNNTLoss()
+        acts, labels, act_length, label_length = get_data_basic(self.device)
+        acts = acts.to(torch.float16)
+        # RuntimeError raised by log_softmax before reaching transducer's bindings
+        self.assertRaises(
+            RuntimeError, rnnt_loss, acts, labels, act_length, label_length
+        )
+
+    def test_basic_backward(self):
+        rnnt_loss = RNNTLoss()
+        acts, labels, act_length, label_length = get_data_basic(self.device)
+        loss = rnnt_loss(acts, labels, act_length, label_length)
+        loss.backward()
+
+    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
+
+        data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3(
+            dtype=torch.float32, device=self.device
+        )
+        logits_shape = data["logits"].shape
+        costs, gradients = compute_with_pytorch_transducer(data=data)
+
+        atol, rtol = 1e-6, 1e-2
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(logits_shape, gradients.shape)
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+
+
+@common_utils.skipIfNoExtension
+class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
+    device = "cpu"
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index a787e853e7..a1a9c42b6a 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -88,3 +88,5 @@ ExternalProject_Add(libsox
   # See https://github.com/pytorch/audio/pull/1026
   CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
 )
+
+add_subdirectory(transducer)
diff --git a/third_party/transducer/CMakeLists.txt b/third_party/transducer/CMakeLists.txt
new file mode 100755
index 0000000000..092cd536a0
--- /dev/null
+++ b/third_party/transducer/CMakeLists.txt
@@ -0,0 +1,38 @@
+CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
+
+PROJECT(rnnt_release)
+
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
+
+IF(APPLE)
+    ADD_DEFINITIONS(-DAPPLE)
+ENDIF()
+
+INCLUDE_DIRECTORIES(submodule/include)
+
+SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
+
+ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)
+
+IF(APPLE)
+    EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
+    STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
+    MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")
+
+    # for El Capitan, have to use rpath
+    IF(DARWIN_VERSION LESS 15)
+        SET(CMAKE_SKIP_RPATH TRUE)
+    ENDIF()
+
+ELSE()
+    # always skip for linux
+    SET(CMAKE_SKIP_RPATH TRUE)
+ENDIF()
+
+ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)
+
+INSTALL(TARGETS warprnnt
+        LIBRARY DESTINATION "lib"
+        ARCHIVE DESTINATION "lib")
+
+INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
diff --git a/third_party/transducer/submodule b/third_party/transducer/submodule
new file mode 160000
index 0000000000..f546575109
--- /dev/null
+++ b/third_party/transducer/submodule
@@ -0,0 +1 @@
+Subproject commit f546575109111c455354861a0567c8aa794208a2
diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp
index 0eb73f1daf..18ac86f388 100644
--- a/torchaudio/csrc/register.cpp
+++ b/torchaudio/csrc/register.cpp
@@ -77,5 +77,19 @@ TORCH_LIBRARY(torchaudio, m) {
   m.def(
       "torchaudio::sox_effects_apply_effects_file",
       &torchaudio::sox_effects::apply_effects_file);
+
+  //////////////////////////////////////////////////////////////////////////////
+  // transducer.cpp
+  //////////////////////////////////////////////////////////////////////////////
+  #ifdef BUILD_TRANSDUCER
+  m.def("rnnt_loss(Tensor acts,"
+        "Tensor labels,"
+        "Tensor input_lengths,"
+        "Tensor label_lengths,"
+        "Tensor costs,"
+        "Tensor grads,"
+        "int blank_label,"
+        "int num_threads) -> int");
+  #endif
 }
 #endif
diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp
new file mode 100644
index 0000000000..2d2b7a8b51
--- /dev/null
+++ b/torchaudio/csrc/transducer.cpp
@@ -0,0 +1,82 @@
+#ifdef BUILD_TRANSDUCER
+
+#include <algorithm>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include <torch/script.h>
+#include "rnnt.h"
+
+int64_t cpu_rnnt_loss(torch::Tensor acts,
+                      torch::Tensor labels,
+                      torch::Tensor input_lengths,
+                      torch::Tensor label_lengths,
+                      torch::Tensor costs,
+                      torch::Tensor grads,
+                      int64_t blank_label,
+                      int64_t num_threads) {
+
+  int maxT = acts.size(1);
+  int maxU = acts.size(2);
+  int minibatch_size = acts.size(0);
+  int alphabet_size = acts.size(3);
+
+  rnntOptions options;
+  memset(&options, 0, sizeof(options));
+  options.maxT = maxT;
+  options.maxU = maxU;
+  options.blank_label = blank_label;
+  options.batch_first = true;
+  options.loc = RNNT_CPU;
+  options.num_threads = num_threads;
+
+  // have to use at least one
+  options.num_threads = std::max(options.num_threads, (unsigned int) 1);
+
+  size_t cpu_size_bytes = 0;
+  switch (acts.scalar_type()) {
+    case torch::ScalarType::Float:
+      {
+        get_workspace_size(maxT, maxU, minibatch_size,
+                           false, &cpu_size_bytes);
+
+        std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0);
+
+        compute_rnnt_loss(acts.data<float>(), grads.data<float>(),
+                          labels.data<int>(), label_lengths.data<int>(),
+                          input_lengths.data<int>(), alphabet_size,
+                          minibatch_size, costs.data<float>(),
+                          cpu_workspace.data(), options);
+
+        return 0;
+      }
+    case torch::ScalarType::Double:
+      {
+        get_workspace_size(maxT, maxU, minibatch_size,
+                           false, &cpu_size_bytes,
+                           sizeof(double));
+
+        std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0);
+
+        compute_rnnt_loss_fp64(acts.data<double>(), grads.data<double>(),
+                               labels.data<int>(), label_lengths.data<int>(),
+                               input_lengths.data<int>(), alphabet_size,
+                               minibatch_size, costs.data<double>(),
+                               cpu_workspace.data(), options);
+
+        return 0;
+      }
+    default:
+      TORCH_CHECK(false,
+          std::string(__func__) + " not implemented for '" + toString(acts.scalar_type()) + "'"
+      );
+  }
+  return -1;
+}
+
+TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("rnnt_loss", &cpu_rnnt_loss);
+}
+
+#endif
diff --git a/torchaudio/prototype/__init__.py b/torchaudio/prototype/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py
new file mode 100644
index 0000000000..6720ee9ea7
--- /dev/null
+++ b/torchaudio/prototype/transducer.py
@@ -0,0 +1,159 @@
+import torch
+from torch.autograd import Function
+from torch.nn import Module
+from torchaudio._internal import (
+    module_utils as _mod_utils,
+)
+
+__all__ = ["rnnt_loss", "RNNTLoss"]
+
+
+class _RNNT(Function):
+    @staticmethod
+    def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
+        """
+        See documentation for RNNTLoss.
+        """
+
+        device = acts.device
+        check_inputs(acts, labels, act_lens, label_lens)
+
+        acts = acts.to("cpu")
+        labels = labels.to("cpu")
+        act_lens = act_lens.to("cpu")
+        label_lens = label_lens.to("cpu")
+
+        loss_func = torch.ops.torchaudio.rnnt_loss
+
+        grads = torch.zeros_like(acts)
+        minibatch_size = acts.size(0)
+        costs = torch.zeros(minibatch_size, dtype=acts.dtype)
+
+        loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0)
+
+        if reduction in ["sum", "mean"]:
+            costs = costs.sum().unsqueeze_(-1)
+            if reduction == "mean":
+                costs /= minibatch_size
+                grads /= minibatch_size
+
+        costs = costs.to(device)
+        ctx.grads = grads.to(device)
+
+        return costs
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
+        return ctx.grads.mul_(grad_output), None, None, None, None, None
+
+
+@_mod_utils.requires_module("torchaudio._torchaudio")
+def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
+    """Compute the RNN Transducer Loss.
+
+    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/abs/1211.3711>`__) extends the CTC loss by defining
+    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
+    dependencies.
+
+    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.
+
+    Args:
+        acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from the network
+            before applying ``torch.nn.functional.log_softmax``.
+        labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
+        act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
+        label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
+        blank (int): blank label. (Default: ``0``)
+        reduction (string): If ``'sum'``, the output losses will be summed.
+            If ``'mean'``, the summed losses will be divided by the batch size.
+            If ``'none'``, no reduction will be applied.
+            (Default: ``'mean'``)
+    """
+
+    # NOTE manually done log_softmax for CPU version,
+    # log_softmax is computed within GPU version.
+    acts = torch.nn.functional.log_softmax(acts, -1)
+    return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction)
+
+
+@_mod_utils.requires_module("torchaudio._torchaudio")
+class RNNTLoss(Module):
+    """Compute the RNN Transducer Loss.
+
+    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/abs/1211.3711>`__) extends the CTC loss by defining
+    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
+    dependencies.
+
+    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.
+
+    Args:
+        blank (int): blank label. (Default: ``0``)
+        reduction (string): If ``'sum'``, the output losses will be summed.
+            If ``'mean'``, the summed losses will be divided by the batch size.
+            If ``'none'``, no reduction will be applied.
+            (Default: ``'mean'``)
+    """
+
+    def __init__(self, blank=0, reduction="mean"):
+        super(RNNTLoss, self).__init__()
+        self.blank = blank
+        self.reduction = reduction
+        self.loss = _RNNT.apply
+
+    def forward(self, acts, labels, act_lens, label_lens):
+        """
+        Args:
+            acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from the network
+                before applying ``torch.nn.functional.log_softmax``.
+            labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
+            act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
+            label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
+        """
+
+        # NOTE manually done log_softmax for CPU version,
+        # log_softmax is computed within GPU version.
+        acts = torch.nn.functional.log_softmax(acts, -1)
+        return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)
+
+
+def check_type(var, t, name):
+    if var.dtype is not t:
+        raise TypeError("{} must be {}".format(name, t))
+
+
+def check_contiguous(var, name):
+    if not var.is_contiguous():
+        raise ValueError("{} must be contiguous".format(name))
+
+
+def check_dim(var, dim, name):
+    if len(var.shape) != dim:
+        raise ValueError("{} must be {}D".format(name, dim))
+
+
+def check_inputs(log_probs, labels, lengths, label_lengths):
+    check_type(labels, torch.int32, "labels")
+    check_type(label_lengths, torch.int32, "label_lengths")
+    check_type(lengths, torch.int32, "lengths")
+    check_contiguous(log_probs, "log_probs")
+    check_contiguous(labels, "labels")
+    check_contiguous(label_lengths, "label_lengths")
+    check_contiguous(lengths, "lengths")
+
+    if lengths.shape[0] != log_probs.shape[0]:
+        raise ValueError("must have a length per example.")
+    if label_lengths.shape[0] != log_probs.shape[0]:
+        raise ValueError("must have a label length per example.")
+
+    check_dim(log_probs, 4, "log_probs")
+    check_dim(labels, 2, "labels")
+    check_dim(lengths, 1, "lengths")
+    check_dim(label_lengths, 1, "label_lengths")
+    max_T = torch.max(lengths)
+    max_U = torch.max(label_lengths)
+    T, U = log_probs.shape[1:3]
+    if T != max_T:
+        raise ValueError("Input length mismatch")
+    if U != max_U + 1:
+        raise ValueError("Output length mismatch")
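
For reviewers who want to exercise the new API end to end, here is a minimal CPU usage sketch. It is not part of the patch; it assumes torchaudio was built with `BUILD_TRANSDUCER=1` (as the CI scripts above now do) and mirrors the shapes used by `get_data_basic` in `transducer_test.py`.

```python
import torch
from torchaudio.prototype.transducer import RNNTLoss

# Joint network output: (batch, time, max_label_length + 1, num_classes),
# given *before* log_softmax -- the CPU path applies log_softmax internally.
acts = torch.rand(1, 2, 3, 5, requires_grad=True)
labels = torch.tensor([[1, 2]], dtype=torch.int32)      # zero-padded label sequences
act_lens = torch.tensor([2], dtype=torch.int32)         # valid time steps per example
label_lens = torch.tensor([2], dtype=torch.int32)       # valid labels per example

loss = RNNTLoss(blank=0, reduction="mean")(acts, labels, act_lens, label_lens)
loss.backward()
print(loss.item(), acts.grad.shape)  # single-element loss; gradient matches acts' shape
```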
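The registered operator itself follows an out-parameter convention: `costs` and `grads` are preallocated by the caller and filled in place, and `_RNNT.forward` is a thin wrapper around that. A hedged sketch of calling the raw op directly, under the same build assumption (the op expects log-probabilities, so `log_softmax` is applied first, exactly as `rnnt_loss` does):

```python
import torch

acts = torch.nn.functional.log_softmax(torch.rand(1, 2, 3, 5), dim=-1)
labels = torch.tensor([[1, 2]], dtype=torch.int32)
act_lens = torch.tensor([2], dtype=torch.int32)
label_lens = torch.tensor([2], dtype=torch.int32)

costs = torch.zeros(1)            # filled with the per-example loss
grads = torch.zeros_like(acts)    # filled with d(loss)/d(acts)

# blank_label=0; num_threads=0 is clamped to 1 inside cpu_rnnt_loss
torch.ops.torchaudio.rnnt_loss(acts, labels, act_lens, label_lens, costs, grads, 0, 0)
```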