Skip to content

Commit

Permalink
Add RNN transducer loss.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentqb committed Dec 31, 2020
1 parent cf11427 commit ca66151
Show file tree
Hide file tree
Showing 13 changed files with 595 additions and 0 deletions.
1 change: 1 addition & 0 deletions .circleci/torchscript_bc_test/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ build_master() {
conda install -y -q pytorch "cpuonly" -c pytorch-nightly
printf "* Installing torchaudio\n"
cd "${_root_dir}" || exit 1
git submodule update --init --recursive
BUILD_SOX=1 python setup.py clean install
}
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ conda activate "${env_dir}"
pip --quiet install cmake ninja

# 4. Build codecs
git submodule update --init --recursive
mkdir -p third_party/build
(
cd third_party/build
Expand Down
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "third_party/warp_transducer/submodule"]
path = third_party/transducer/submodule
url = https://github.com/HawkAaron/warp-transducer
branch = master
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _get_srcs():
def _get_include_dirs():
dirs = [
str(_ROOT_DIR),
str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'),
]
if _BUILD_SOX:
dirs.append(str(_TP_INSTALL_DIR / 'include'))
Expand Down Expand Up @@ -94,6 +95,7 @@ def _get_extra_objects():
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
return objs


Expand Down
1 change: 1 addition & 0 deletions packaging/pkg_helpers.bash
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ setup_macos() {
#
# Usage: setup_env 0.2.0
setup_env() {
git submodule update --init --recursive
setup_cuda
setup_build_version "$1"
setup_macos
Expand Down
291 changes: 291 additions & 0 deletions test/torchaudio_unittest/transducer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
import numpy as np
import torch

from torchaudio_unittest import common_utils
from torchaudio.prototype.transducer import RNNTLoss


def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
    """Return a fixed RNN-T test case together with its reference results.

    The activations and the reference gradients are stored as flat lists of
    72 values and reshaped to ``(2, 4, 3, 3)``. Going by the function name
    the axes are presumably (batch, time, target length + 1, classes).

    Args:
        dtype: NumPy dtype used for the logits, the reference costs and the
            reference gradients. Targets and lengths are always ``int32``.

    Returns:
        Tuple ``(data, ref_costs, ref_gradients)``. ``data`` is a dict with
        the keys ``"logits"``, ``"targets"``, ``"src_lengths"``,
        ``"tgt_lengths"`` and ``"blank"``.
    """
    shape = (2, 4, 3, 3)

    # One line per innermost vector (three values each).
    logit_values = [
        0.065357, 0.787530, 0.081592,
        0.529716, 0.750675, 0.754135,
        0.609764, 0.868140, 0.622532,
        0.668522, 0.858039, 0.164539,
        0.989780, 0.944298, 0.603168,
        0.946783, 0.666203, 0.286882,
        0.094184, 0.366674, 0.736168,
        0.166680, 0.714154, 0.399400,
        0.535982, 0.291821, 0.612642,
        0.324241, 0.800764, 0.524106,
        0.779195, 0.183314, 0.113745,
        0.240222, 0.339470, 0.134160,
        0.505562, 0.051597, 0.640290,
        0.430733, 0.829473, 0.177467,
        0.320700, 0.042883, 0.302803,
        0.675178, 0.569537, 0.558474,
        0.083132, 0.060165, 0.107958,
        0.748615, 0.943918, 0.486356,
        0.418199, 0.652408, 0.024243,
        0.134582, 0.366342, 0.295830,
        0.923670, 0.689929, 0.741898,
        0.250005, 0.603430, 0.987289,
        0.592606, 0.884672, 0.543450,
        0.660770, 0.377128, 0.358021,
    ]

    gradient_values = [
        -0.186844, -0.062555, 0.249399,
        -0.203377, 0.202399, 0.000977,
        -0.141016, 0.079123, 0.061893,
        -0.011552, -0.081280, 0.092832,
        -0.154257, 0.229433, -0.075176,
        -0.246593, 0.146405, 0.100188,
        -0.012918, -0.061593, 0.074512,
        -0.055986, 0.219831, -0.163845,
        -0.497627, 0.209240, 0.288387,
        0.013605, -0.030220, 0.016615,
        0.113925, 0.062781, -0.176706,
        -0.667078, 0.367659, 0.299419,
        -0.356344, -0.055347, 0.411691,
        -0.096922, 0.029459, 0.067463,
        -0.063518, 0.027654, 0.035863,
        -0.154499, -0.073942, 0.228441,
        -0.166790, -0.000088, 0.166878,
        -0.172370, 0.105565, 0.066804,
        0.023875, -0.118256, 0.094381,
        -0.104707, -0.108934, 0.213642,
        -0.369844, 0.180118, 0.189726,
        0.025714, -0.079462, 0.053748,
        0.122328, -0.238789, 0.116460,
        -0.598687, 0.302203, 0.296484,
    ]

    ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)
    ref_gradients = np.array(gradient_values, dtype=dtype).reshape(*shape)

    data = dict(
        logits=np.array(logit_values, dtype=dtype).reshape(*shape),
        targets=np.array([[1, 2], [1, 1]], dtype=np.int32),
        src_lengths=np.array([4, 4], dtype=np.int32),
        tgt_lengths=np.array([2, 2], dtype=np.int32),
        blank=0,
    )

    return data, ref_costs, ref_gradients


def numpy_to_torch(data, device, requires_grad=True):
    """Convert the NumPy arrays of a test case into torch tensors.

    ``data`` is updated in place and also returned; the ``"blank"`` entry
    (and any other key) is left untouched.

    Args:
        data: Dict holding NumPy arrays under ``"logits"``, ``"targets"``,
            ``"src_lengths"`` and ``"tgt_lengths"``.
        device: Device all four tensors are moved to.
        requires_grad: Whether the logits track gradients. When true, a
            backward hook stores a copy of the incoming gradient on the
            logits tensor as ``saved_grad``.

    Returns:
        The same ``data`` dict, with the four arrays replaced by tensors.
    """
    # Convert everything first so that a missing key leaves ``data`` intact.
    logits = torch.from_numpy(data["logits"])
    targets = torch.from_numpy(data["targets"]).to(device)
    src_lengths = torch.from_numpy(data["src_lengths"]).to(device)
    tgt_lengths = torch.from_numpy(data["tgt_lengths"]).to(device)

    logits.requires_grad_(requires_grad)
    logits = logits.to(device)

    # register_hook raises RuntimeError on a tensor that does not require
    # gradients, so only install the hook when gradients are tracked.
    if requires_grad:

        def grad_hook(grad):
            logits.saved_grad = grad.clone()

        logits.register_hook(grad_hook)

    data["logits"] = logits
    data["src_lengths"] = src_lengths
    data["tgt_lengths"] = tgt_lengths
    data["targets"] = targets

    return data


def compute_with_pytorch_transducer(data):
    """Run ``RNNTLoss`` forward and backward on one test case.

    Args:
        data: Dict with ``"blank"``, ``"targets"``, ``"src_lengths"``,
            ``"tgt_lengths"`` and ``"logits"``. If ``"logits_sparse"`` is
            present it is used as the activations instead of ``"logits"``.
            ``data["logits"]`` must expose a ``saved_grad`` attribute after
            the backward pass (``numpy_to_torch`` installs a hook for that).

    Returns:
        Tuple ``(costs, gradients)`` of NumPy arrays: the unreduced
        per-sequence costs and the gradient recorded on ``data["logits"]``.
    """
    criterion = RNNTLoss(blank=data["blank"], reduction="none")
    acts_key = "logits_sparse" if "logits_sparse" in data else "logits"
    costs = criterion(
        acts=data[acts_key],
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )

    # Sum the per-sequence costs to obtain a scalar to differentiate.
    torch.sum(costs).backward()

    # detach() is the supported way to leave autograd before converting to
    # NumPy; the legacy ``.data`` accessor is avoided.
    costs = costs.detach().cpu().numpy()
    gradients = data["logits"].saved_grad.detach().cpu().numpy()
    return costs, gradients


class TransducerTester:
    """RNN-T loss checks shared across devices.

    Subclasses are expected to provide a ``device`` attribute and the
    ``assertEqual`` method of a test case base class.
    """

    def test_basic_backward(self):
        # Smoke test: the example from the warp-transducer README must run
        # forward and backward without raising.
        # https://github.com/HawkAaron/warp-transducer
        device = self.device
        criterion = RNNTLoss()

        readme_acts = [
            [
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.8, 0.1],
                ],
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.1, 0.1],
                    [0.7, 0.1, 0.2, 0.1, 0.1],
                ],
            ]
        ]

        acts = torch.tensor(readme_acts, dtype=torch.float32).to(device)
        labels = torch.tensor([[1, 2]], dtype=torch.int32).to(device)
        act_length = torch.tensor([2], dtype=torch.int32).to(device)
        label_length = torch.tensor([2], dtype=torch.int32).to(device)

        acts.requires_grad_(True)

        criterion(acts, labels, act_length, label_length).backward()

    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        # Compare the computed costs and gradients against the references.
        # When the gradients differ, re-check them one (b, t, u) vector at a
        # time so the failure message names the first offending position.
        expected_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)

        np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(expected_shape, gradients.shape)

        if np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
            return

        for b, t, u in np.ndindex(*gradients.shape[:3]):
            T = data["src_lengths"][b]
            U = data["tgt_lengths"][b]
            np.testing.assert_allclose(
                gradients[b, t, u],
                ref_gradients[b, t, u],
                atol=atol,
                rtol=rtol,
                err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
            )

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        # float32 reference case, converted to tensors on self.device.
        case, expected_costs, expected_gradients = get_numpy_data_B2_T4_U3_D3(
            dtype=np.float32
        )
        tensors = numpy_to_torch(data=case, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=tensors,
            ref_costs=expected_costs,
            ref_gradients=expected_gradients,
        )


@common_utils.skipIfNoExtension
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
    """Runs the shared ``TransducerTester`` checks with ``device`` set to CPU."""

    device = "cpu"
2 changes: 2 additions & 0 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ ExternalProject_Add(libsox
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
)

add_subdirectory(transducer)
48 changes: 48 additions & 0 deletions third_party/transducer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Builds the static library `warprnnt` from the sources under submodule/.
MESSAGE("path to cmake current source dir: ${CMAKE_CURRENT_SOURCE_DIR}")
# Apple builds require CMake >= 3.4; every other platform only needs 2.8.
IF(APPLE)
CMAKE_MINIMUM_REQUIRED(VERSION 3.4)
ELSE()
CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
ENDIF()

PROJECT(rnnt_release)

# Apple gets C++11 and -O2 here (plus the APPLE define); other platforms get
# -O2 now and have -std=c++11 appended further down.
IF(APPLE)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
ADD_DEFINITIONS(-DAPPLE)
ELSE()
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
ENDIF()

# Headers of the vendored submodule.
INCLUDE_DIRECTORIES(submodule/include)

# Position-independent code is enabled for the targets defined below.
SET(CMAKE_POSITION_INDEPENDENT_CODE ON)

# NOTE(review): presumably turns off OpenMP usage inside the submodule
# sources - confirm against the submodule code.
ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)

IF(APPLE)
# DARWIN_VERSION becomes the first run of digits in the output of `uname -v`.
EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")

# El Capitan has to use rpath, so rpath is only skipped on Darwin versions below 15
IF(DARWIN_VERSION LESS 15)
SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

ELSE()
# always skip for linux
SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

# NOTE(review): -O2 was already appended for non-Apple platforms above, so it
# ends up in CMAKE_CXX_FLAGS twice there.
IF(NOT APPLE)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
ENDIF()

# The library consists of the single entry-point translation unit.
ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)

# Install rules for the library and its public header.
INSTALL(TARGETS warprnnt
LIBRARY DESTINATION "lib"
ARCHIVE DESTINATION "archives")

INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
1 change: 1 addition & 0 deletions third_party/transducer/submodule
Submodule submodule added at f54657
Loading

0 comments on commit ca66151

Please sign in to comment.