Add torchscript support to RNNT Loss (#1507)
Caroline Chen authored May 19, 2021
1 parent 079b3f5 commit af7eb4d
Showing 9 changed files with 234 additions and 107 deletions.
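
In short, this commit makes both the functional rnnt_loss and the RNNTLoss module scriptable, and the new tests below check that a scripted copy matches eager execution after a save/load round trip. A minimal usage sketch, assuming the prototype import path used in the tests (torchaudio.prototype.rnnt_loss); shapes and values here are illustrative only:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Toy inputs: batch=1, 2 source frames, 2 target labels (3 alignment steps), 5 classes.
logits = torch.rand(1, 2, 3, 5)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss_module = RNNTLoss()
scripted = torch.jit.script(loss_module)  # scripting is what this commit enables

eager_out = loss_module(logits, targets, logit_lengths, target_lengths)
scripted_out = scripted(logits, targets, logit_lengths, target_lengths)
assert torch.allclose(eager_out, scripted_out)
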
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
@@ -0,0 +1,10 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')
11 changes: 11 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
@@ -0,0 +1,11 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
70 changes: 70 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
@@ -0,0 +1,70 @@
import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
"""Implements test for RNNT Loss that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/rnnt/utils.py
@@ -405,10 +405,10 @@ def get_numpy_random_data(


def numpy_to_torch(data, device, requires_grad=True):
    logits = torch.from_numpy(data["logits"])
    targets = torch.from_numpy(data["targets"])
    logit_lengths = torch.from_numpy(data["logit_lengths"])
    target_lengths = torch.from_numpy(data["target_lengths"])
    logits = torch.from_numpy(data["logits"]).to(device=device)
    targets = torch.from_numpy(data["targets"]).to(device=device)
    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)

    if "nbest_wers" in data:
        data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
rnnt/compute_alphas.cpp
rnnt/compute_betas.cpp
rnnt/compute.cpp
rnnt/autograd.cpp
)

if (USE_CUDA)
74 changes: 74 additions & 0 deletions torchaudio/csrc/rnnt/autograd.cpp
@@ -0,0 +1,74 @@
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& src_lengths,
      const torch::Tensor& tgt_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_smax = true,
      bool reuse_logits_for_grads = true) {
    at::AutoNonVariableTypeMode g;
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        src_lengths,
        tgt_lengths,
        blank,
        clamp,
        fused_log_smax,
        reuse_logits_for_grads);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
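
The Autograd registration above is what makes the dispatched operator differentiable: RNNTLossFunction caches the gradients computed in forward and scales them by the incoming grad_output in backward. A minimal sketch of exercising that path from Python, assuming the prototype rnnt_loss imported in the tests calls the bound operator and returns per-sequence costs (shapes are illustrative):

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
costs.sum().backward()      # under the assumption above, routes through RNNTLossFunction::backward
print(logits.grad.shape)    # matches logits: torch.Size([1, 2, 3, 5])
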
24 changes: 24 additions & 0 deletions torchaudio/csrc/rnnt/compute.cpp
@@ -1,4 +1,28 @@
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
13 changes: 13 additions & 0 deletions torchaudio/csrc/rnnt/compute.h
@@ -0,0 +1,13 @@
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax,
    bool reuse_logits_for_grads);