-
Notifications
You must be signed in to change notification settings - Fork 664
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add torchscript support to RNNT Loss (#1507)
- Loading branch information
Caroline Chen
authored
May 19, 2021
1 parent
079b3f5
commit af7eb4d
Showing
9 changed files
with
234 additions
and
107 deletions.
There are no files selected for viewing
10 changes: 10 additions & 0 deletions
10
test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
|
||
from torchaudio_unittest.common_utils import PytorchTestCase | ||
from .utils import skipIfNoTransducer | ||
from .torchscript_consistency_impl import RNNTLossTorchscript | ||
|
||
|
||
@skipIfNoTransducer | ||
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): | ||
device = torch.device('cpu') |
11 changes: 11 additions & 0 deletions
11
test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import torch | ||
|
||
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda | ||
from .utils import skipIfNoTransducer | ||
from .torchscript_consistency_impl import RNNTLossTorchscript | ||
|
||
|
||
@skipIfNoTransducer | ||
@skipIfNoCuda | ||
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): | ||
device = torch.device('cuda') |
70 changes: 70 additions & 0 deletions
70
test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import torch | ||
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin | ||
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss | ||
|
||
|
||
class RNNTLossTorchscript(TempDirMixin, TestBaseMixin): | ||
"""Implements test for RNNT Loss that are performed for different devices""" | ||
def _assert_consistency(self, func, tensor, shape_only=False): | ||
tensor = tensor.to(device=self.device, dtype=self.dtype) | ||
|
||
path = self.get_temp_path('func.zip') | ||
torch.jit.script(func).save(path) | ||
ts_func = torch.jit.load(path) | ||
|
||
torch.random.manual_seed(40) | ||
input_tensor = tensor.clone().detach().requires_grad_(True) | ||
output = func(input_tensor) | ||
|
||
torch.random.manual_seed(40) | ||
input_tensor = tensor.clone().detach().requires_grad_(True) | ||
ts_output = ts_func(input_tensor) | ||
|
||
self.assertEqual(ts_output, output) | ||
|
||
def test_rnnt_loss(self): | ||
def func( | ||
logits, | ||
): | ||
targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) | ||
logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) | ||
target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) | ||
return rnnt_loss(logits, targets, logit_lengths, target_lengths) | ||
|
||
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], | ||
[0.1, 0.1, 0.6, 0.1, 0.1], | ||
[0.1, 0.1, 0.2, 0.8, 0.1]], | ||
[[0.1, 0.6, 0.1, 0.1, 0.1], | ||
[0.1, 0.1, 0.2, 0.1, 0.1], | ||
[0.7, 0.1, 0.2, 0.1, 0.1]]]]) | ||
|
||
self._assert_consistency(func, logits) | ||
|
||
def test_RNNTLoss(self): | ||
func = RNNTLoss() | ||
|
||
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], | ||
[0.1, 0.1, 0.6, 0.1, 0.1], | ||
[0.1, 0.1, 0.2, 0.8, 0.1]], | ||
[[0.1, 0.6, 0.1, 0.1, 0.1], | ||
[0.1, 0.1, 0.2, 0.1, 0.1], | ||
[0.7, 0.1, 0.2, 0.1, 0.1]]]]) | ||
targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32) | ||
logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) | ||
target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) | ||
|
||
tensor = logits.to(device=self.device, dtype=self.dtype) | ||
|
||
path = self.get_temp_path('func.zip') | ||
torch.jit.script(func).save(path) | ||
ts_func = torch.jit.load(path) | ||
|
||
torch.random.manual_seed(40) | ||
input_tensor = tensor.clone().detach().requires_grad_(True) | ||
output = func(input_tensor, targets, logit_lengths, target_lengths) | ||
|
||
torch.random.manual_seed(40) | ||
input_tensor = tensor.clone().detach().requires_grad_(True) | ||
ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths) | ||
|
||
self.assertEqual(ts_output, output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#include <torch/script.h> | ||
#include <torchaudio/csrc/rnnt/compute.h> | ||
|
||
namespace torchaudio { | ||
namespace rnnt { | ||
|
||
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> { | ||
public: | ||
static torch::autograd::tensor_list forward( | ||
torch::autograd::AutogradContext* ctx, | ||
torch::Tensor& logits, | ||
const torch::Tensor& targets, | ||
const torch::Tensor& src_lengths, | ||
const torch::Tensor& tgt_lengths, | ||
int64_t blank, | ||
double clamp, | ||
bool fused_log_smax = true, | ||
bool reuse_logits_for_grads = true) { | ||
at::AutoNonVariableTypeMode g; | ||
torch::Tensor undef; | ||
auto result = rnnt_loss( | ||
logits, | ||
targets, | ||
src_lengths, | ||
tgt_lengths, | ||
blank, | ||
clamp, | ||
fused_log_smax, | ||
reuse_logits_for_grads); | ||
auto costs = std::get<0>(result); | ||
auto grads = std::get<1>(result).value_or(undef); | ||
ctx->save_for_backward({grads}); | ||
return {costs, grads}; | ||
} | ||
|
||
static torch::autograd::tensor_list backward( | ||
torch::autograd::AutogradContext* ctx, | ||
torch::autograd::tensor_list grad_outputs) { | ||
auto saved = ctx->get_saved_variables(); | ||
auto grad = saved[0]; | ||
auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); | ||
auto result = grad * grad_out; | ||
torch::Tensor undef; | ||
return {result, undef, undef, undef, undef, undef, undef, undef}; | ||
} | ||
}; | ||
|
||
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd( | ||
torch::Tensor& logits, | ||
const torch::Tensor& targets, | ||
const torch::Tensor& src_lengths, | ||
const torch::Tensor& tgt_lengths, | ||
int64_t blank, | ||
double clamp, | ||
bool fused_log_smax = true, | ||
bool reuse_logits_for_grads = true) { | ||
auto results = RNNTLossFunction::apply( | ||
logits, | ||
targets, | ||
src_lengths, | ||
tgt_lengths, | ||
blank, | ||
clamp, | ||
fused_log_smax, | ||
reuse_logits_for_grads); | ||
return std::make_tuple(results[0], results[1]); | ||
} | ||
|
||
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { | ||
m.impl("rnnt_loss", rnnt_loss_autograd); | ||
} | ||
|
||
} // namespace rnnt | ||
} // namespace torchaudio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#pragma once | ||
|
||
#include <torch/script.h> | ||
|
||
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss( | ||
torch::Tensor& logits, | ||
const torch::Tensor& targets, | ||
const torch::Tensor& src_lengths, | ||
const torch::Tensor& tgt_lengths, | ||
int64_t blank, | ||
double clamp, | ||
bool fused_log_smax, | ||
bool reuse_logits_for_grads); |
Oops, something went wrong.