add warp transducer as submodule
vincentqb committed Dec 17, 2020
1 parent d25a4dd commit f4afbbb
Showing 6 changed files with 258 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .gitmodules
@@ -0,0 +1,5 @@
[submodule "third_party/warp_transducer"]
path = third_party/warp_transducer
# url = https://github.com/HawkAaron/warp-transducer
url = https://github.com/vincentqb/warp-transducer
branch = torchbind
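
On a fresh clone, the pinned fork must be fetched before the extension can build. A minimal sketch of initializing the submodule (a standard git invocation, wrapped in subprocess for consistency with the build helpers below; the exact workflow is an assumption, not part of this commit):

```python
import subprocess

# Fetch the warp-transducer fork pinned by .gitmodules (torchbind branch).
subprocess.run(
    args=["git", "submodule", "update", "--init", "--recursive",
          "third_party/warp_transducer"],
    check=True,
)
```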
71 changes: 59 additions & 12 deletions build_tools/setup_helpers/extension.py
@@ -17,6 +17,7 @@
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc'
_TP_BASE_DIR = _ROOT_DIR / 'third_party'
_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer'
_TP_INSTALL_DIR = _TP_BASE_DIR / 'install'


@@ -101,8 +102,8 @@ def _get_libraries():
return [] if _BUILD_SOX else ['sox']


def _build_third_party():
build_dir = str(_TP_BASE_DIR / 'build')
def _build_third_party(base_dir):
build_dir = str(base_dir / 'build')
os.makedirs(build_dir, exist_ok=True)
subprocess.run(
args=['cmake', '..'],
@@ -115,6 +116,57 @@ def _build_third_party():
check=True,
)

def _get_ext(debug):
return CppExtension(
_EXT_NAME,
_get_srcs(),
libraries=_get_libraries(),
include_dirs=_get_include_dirs(),
extra_compile_args=_get_eca(debug),
extra_objects=_get_extra_objects(),
extra_link_args=_get_ela(debug),
)


def _get_ext_rnnt(debug):
import torch

extra_compile_args = ['-fPIC']
extra_compile_args += ['-std=c++14']
base_path = _TP_TRANSDUCER_BASE_DIR
default_warp_rnnt_path = base_path / "build"

if torch.cuda.is_available():

if "CUDA_HOME" not in os.environ:
raise RuntimeError("Please specify the environment variable CUDA_HOME")

enable_gpu = True

else:
print("Torch was not built with CUDA support, not building GPU extensions.")
enable_gpu = False

if enable_gpu:
extra_compile_args += ['-DWARPRNNT_ENABLE_GPU']

if "WARP_RNNT_PATH" in os.environ:
warp_rnnt_path = os.environ["WARP_RNNT_PATH"]
else:
warp_rnnt_path = default_warp_rnnt_path
include_dirs = [os.path.realpath(os.path.join(base_path, 'include'))]

return CppExtension(
name='_warp_transducer',
sources=[os.path.realpath(base_path / 'pytorch_binding' / 'src' / 'binding.cpp')],
include_dirs=include_dirs,
library_dirs=[os.path.realpath(warp_rnnt_path)],
libraries=['warprnnt'],
extra_link_args=['-Wl,-rpath,' + os.path.realpath(warp_rnnt_path)],
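        # rpath lets the dynamic loader find libwarprnnt in the build tree at
        # runtime without requiring LD_LIBRARY_PATH.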
extra_compile_args=extra_compile_args
)


_EXT_NAME = 'torchaudio._torchaudio'

@@ -123,20 +175,15 @@ def get_ext_modules(debug=False):
if platform.system() == 'Windows':
return None
return [
CppExtension(
_EXT_NAME,
_get_srcs(),
libraries=_get_libraries(),
include_dirs=_get_include_dirs(),
extra_compile_args=_get_eca(debug),
extra_objects=_get_extra_objects(),
extra_link_args=_get_ela(debug),
),
_get_ext(debug),
_get_ext_rnnt(debug),
]


class BuildExtension(TorchBuildExtension):
def build_extension(self, ext):
if ext.name == _EXT_NAME and _BUILD_SOX:
_build_third_party()
_build_third_party(_TP_BASE_DIR)
if ext.name == "_warp_transducer":
_build_third_party(_TP_TRANSDUCER_BASE_DIR)
super().build_extension(ext)
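
With this change, _build_third_party(base_dir) runs the same CMake configure-and-build sequence for whichever third-party root it is given (the sox tree or the warp-transducer tree). A sketch of the full helper follows; the second subprocess.run call's arguments are collapsed in the diff above, so the `cmake --build .` compile step here is an assumption:

```python
import os
import subprocess
from pathlib import Path


def _build_third_party(base_dir: Path) -> None:
    # Out-of-tree CMake build rooted at <base_dir>/build.
    build_dir = str(base_dir / 'build')
    os.makedirs(build_dir, exist_ok=True)
    # Configure step, as shown in the diff above.
    subprocess.run(args=['cmake', '..'], cwd=build_dir, check=True)
    # Compile step; the real arguments are collapsed in the diff, so this
    # 'cmake --build .' invocation is an assumption.
    subprocess.run(args=['cmake', '--build', '.'], cwd=build_dir, check=True)
```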
4 changes: 3 additions & 1 deletion setup.py
@@ -51,6 +51,7 @@ def run(self):
build_dirs = [
ROOT_DIR / 'build',
ROOT_DIR / 'third_party' / 'build',
ROOT_DIR / 'third_party' / 'warp_transducer' / 'build',
]
for path in build_dirs:
if path.exists():
@@ -83,7 +84,8 @@ def run(self):
packages=find_packages(exclude=["build*", "test*", "torchaudio.csrc*", "third_party*", "build_tools*"]),
ext_modules=setup_helpers.get_ext_modules(),
cmdclass={
'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True)
'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True),
'clean': clean,
},
install_requires=[pytorch_package_dep],
zip_safe=False,
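
The cmdclass change registers a custom clean command whose body appears only in part above (the build_dirs hunk). A plausible reconstruction, assuming it extends distutils' clean command:

```python
import shutil
from pathlib import Path
from distutils.command.clean import clean as distutils_clean

ROOT_DIR = Path(__file__).parent.resolve()


class clean(distutils_clean):
    def run(self):
        # Remove the main build tree plus both third-party build trees.
        build_dirs = [
            ROOT_DIR / 'build',
            ROOT_DIR / 'third_party' / 'build',
            ROOT_DIR / 'third_party' / 'warp_transducer' / 'build',
        ]
        for path in build_dirs:
            if path.exists():
                shutil.rmtree(str(path))
        super().run()
```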
51 changes: 51 additions & 0 deletions test/torchaudio_unittest/transducer_test.py
@@ -0,0 +1,51 @@
import torch

from torchaudio_unittest import common_utils
from torchaudio.transducer import RNNTLoss


class TransducerTester:
def test_basic_backward(self):
rnnt_loss = RNNTLoss()

acts = torch.FloatTensor(
[
[
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1],
],
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1],
],
]
]
)
labels = torch.IntTensor([[1, 2]])
act_length = torch.IntTensor([2])
label_length = torch.IntTensor([2])

acts = acts.to(self.device)
labels = labels.to(self.device)
act_length = act_length.to(self.device)
label_length = label_length.to(self.device)

        acts.requires_grad_(True)

loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()


class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
device = "cpu"


@common_utils.skipIfNoCuda
class GPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
device = "cuda"
3 changes: 2 additions & 1 deletion torchaudio/__init__.py
@@ -6,7 +6,8 @@
kaldi_io,
utils,
sox_effects,
transforms
transforms,
transducer,
)

USE_SOUNDFILE_LEGACY_INTERFACE = None
138 changes: 138 additions & 0 deletions torchaudio/transducer/__init__.py
@@ -0,0 +1,138 @@
import torch
import _warp_transducer as warp_rnnt
from torch.autograd import Function
from torch.nn import Module

__all__ = ["rnnt_loss", "RNNTLoss"]


class _RNNT(Function):
@staticmethod
def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
"""
        acts: Tensor of (batch x seqLength x labelLength x outputDim) containing the output from the network
        labels: 2-dimensional Tensor containing the zero-padded targets of the batch
        act_lens: Tensor of size (batch) containing the length of each output sequence from the network
        label_lens: Tensor of size (batch) containing the label length of each example
"""
is_cuda = acts.is_cuda

certify_inputs(acts, labels, act_lens, label_lens)

loss_func = warp_rnnt.gpu_rnnt if is_cuda else warp_rnnt.cpu_rnnt
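        # The kernel fills `grads` with d(loss)/d(acts) during the forward
        # pass; allocate a full-size buffer only when a gradient is required.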
grads = (
torch.zeros_like(acts) if acts.requires_grad else torch.zeros(0).to(acts)
)
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, dtype=acts.dtype)
loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0)

if reduction in ["sum", "mean"]:
costs = costs.sum().unsqueeze_(-1)
if reduction == "mean":
costs /= minibatch_size
grads /= minibatch_size

costs = costs.to(acts.device)
ctx.grads = grads

return costs

@staticmethod
def backward(ctx, grad_output):
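        # The full gradient was already computed and cached in forward;
        # backward only rescales it by the incoming grad_output, broadcast
        # over the (T, U, V) dimensions.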
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul_(grad_output), None, None, None, None, None


def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
"""RNN Transducer Loss
Args:
        acts: Tensor of (batch x seqLength x labelLength x outputDim) containing the output from the network
        labels: 2-dimensional Tensor containing the zero-padded targets of the batch
        act_lens: Tensor of size (batch) containing the length of each output sequence from the network
        label_lens: Tensor of size (batch) containing the label length of each example
blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied;
            'mean': the output losses are summed and divided by the batch size;
            'sum': the output losses are summed. Default: 'mean'
"""
if not acts.is_cuda:
acts = torch.nn.functional.log_softmax(acts, -1)

return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction)


class RNNTLoss(Module):
"""
Parameters:
blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied;
            'mean': the output losses are summed and divided by the batch size;
            'sum': the output losses are summed. Default: 'mean'
"""

def __init__(self, blank=0, reduction="mean"):
super(RNNTLoss, self).__init__()
self.blank = blank
self.reduction = reduction
self.loss = _RNNT.apply

def forward(self, acts, labels, act_lens, label_lens):
"""
        acts: Tensor of (batch x seqLength x labelLength x outputDim) containing the output from the network
        labels: 2-dimensional Tensor containing the zero-padded targets of the batch
        act_lens: Tensor of size (batch) containing the length of each output sequence from the network
        label_lens: Tensor of size (batch) containing the label length of each example
"""
if not acts.is_cuda:
            # NOTE: log_softmax is applied manually for the CPU version;
            # the GPU kernel computes log_softmax internally.
acts = torch.nn.functional.log_softmax(acts, -1)

return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)


def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))


def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))


def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))


def certify_inputs(log_probs, labels, lengths, label_lengths):
# check_type(log_probs, torch.float32, "log_probs")
check_type(labels, torch.int32, "labels")
check_type(label_lengths, torch.int32, "label_lengths")
check_type(lengths, torch.int32, "lengths")
check_contiguous(log_probs, "log_probs")
check_contiguous(labels, "labels")
check_contiguous(label_lengths, "label_lengths")
check_contiguous(lengths, "lengths")

if lengths.shape[0] != log_probs.shape[0]:
raise ValueError("must have a length per example.")
if label_lengths.shape[0] != log_probs.shape[0]:
raise ValueError("must have a label length per example.")

check_dim(log_probs, 4, "log_probs")
check_dim(labels, 2, "labels")
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")
max_T = torch.max(lengths)
max_U = torch.max(label_lengths)
T, U = log_probs.shape[1:3]
if T != max_T:
raise ValueError("Input length mismatch")
if U != max_U + 1:
raise ValueError("Output length mismatch")
