Skip to content

Commit

Permalink
replace numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentqb committed Jan 5, 2021
1 parent a490d87 commit 1aed0d2
Showing 1 changed file with 32 additions and 63 deletions.
95 changes: 32 additions & 63 deletions test/torchaudio_unittest/transducer_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
import torch
from torchaudio.prototype.transducer import RNNTLoss

from torchaudio_unittest import common_utils


def get_basic_data(device):
def get_data_basic(device):
# Example provided
# in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer
Expand Down Expand Up @@ -41,8 +40,10 @@ def get_basic_data(device):
return acts, labels, act_length, label_length


def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
logits = np.array(
def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
# Test from D21322854

logits = torch.tensor(
[
0.065357,
0.787530,
Expand Down Expand Up @@ -120,15 +121,15 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
dtype=dtype,
).reshape(2, 4, 3, 3)

targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
src_lengths = np.array([4, 4], dtype=np.int32)
tgt_lengths = np.array([2, 2], dtype=np.int32)
targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
src_lengths = torch.tensor([4, 4], dtype=torch.int32)
tgt_lengths = torch.tensor([2, 2], dtype=torch.int32)

blank = 0

ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)

ref_gradients = np.array(
ref_gradients = torch.tensor(
[
-0.186844,
-0.062555,
Expand Down Expand Up @@ -206,6 +207,14 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
dtype=dtype,
).reshape(2, 4, 3, 3)

logits.requires_grad_(True)
logits = logits.to(device)

def grad_hook(grad):
logits.saved_grad = grad.clone()

logits.register_hook(grad_hook)

data = {
"logits": logits,
"targets": targets,
Expand All @@ -217,49 +226,25 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
return data, ref_costs, ref_gradients


def numpy_to_torch(data, device, requires_grad=True):
    """Convert the NumPy arrays in ``data`` to torch tensors, in place.

    The entries ``"logits"``, ``"targets"``, ``"src_lengths"`` and
    ``"tgt_lengths"`` are replaced by tensors built with
    ``torch.from_numpy``; any other entries are left untouched.  Only the
    logits are moved to ``device``.

    When the logits require grad, a hook is registered that stores a copy
    of the incoming gradient on the logits tensor as ``saved_grad`` so it
    can be inspected after ``backward()``.

    Args:
        data: dict holding the four NumPy arrays listed above.
        device: target device for the logits tensor.
        requires_grad: whether the logits should track gradients.

    Returns:
        The same ``data`` dict, with the four entries replaced by tensors.
    """
    # Convert everything first so that a missing key raises before ``data``
    # has been partially modified.
    logits = torch.from_numpy(data["logits"])
    targets = torch.from_numpy(data["targets"])
    src_lengths = torch.from_numpy(data["src_lengths"])
    tgt_lengths = torch.from_numpy(data["tgt_lengths"])

    logits.requires_grad_(requires_grad)
    logits = logits.to(device)

    # register_hook raises RuntimeError on a tensor that does not require
    # grad, so only install the hook when gradients are actually tracked.
    if logits.requires_grad:

        def grad_hook(grad):
            logits.saved_grad = grad.clone()

        logits.register_hook(grad_hook)

    data["logits"] = logits
    data["src_lengths"] = src_lengths
    data["tgt_lengths"] = tgt_lengths
    data["targets"] = targets

    return data


def compute_with_pytorch_transducer(data):
    """Run RNNTLoss forward and backward on ``data``.

    The loss is built with ``reduction="none"`` and the blank index taken
    from ``data["blank"]``.  The unreduced costs are summed and
    back-propagated.

    Args:
        data: dict with ``"logits"`` (or ``"logits_sparse"``, which takes
            precedence when present), ``"targets"``, ``"src_lengths"``,
            ``"tgt_lengths"`` and ``"blank"``.  ``data["logits"]`` must
            expose a ``saved_grad`` attribute after the backward pass,
            i.e. a gradient hook storing it must already be registered.

    Returns:
        A ``(costs, gradients)`` tuple of CPU tensors: the unreduced costs
        (detached from the autograd graph) and the gradient saved on
        ``data["logits"]``.
    """
    acts = data["logits_sparse"] if "logits_sparse" in data else data["logits"]
    costs = RNNTLoss(blank=data["blank"], reduction="none")(
        acts=acts,
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )

    loss = torch.sum(costs)
    loss.backward()

    # Detach before handing the costs back so callers can compare them with
    # either torch or NumPy helpers without tripping over autograd state.
    costs = costs.detach().cpu()
    gradients = data["logits"].saved_grad.cpu()
    return costs, gradients


class TransducerTester:
def test_basic_fp16_error(self):
rnnt_loss = RNNTLoss()
acts, labels, act_length, label_length = get_basic_data(self.device)
acts, labels, act_length, label_length = get_data_basic(self.device)
acts = acts.to(torch.float16)
# RuntimeError raised by log_softmax before reaching transducer's bindings
self.assertRaises(
Expand All @@ -268,38 +253,22 @@ def test_basic_fp16_error(self):

def test_basic_backward(self):
rnnt_loss = RNNTLoss()
acts, labels, act_length, label_length = get_basic_data(self.device)
acts, labels, act_length, label_length = get_data_basic(self.device)
loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()

def _test_costs_and_gradients(
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):

data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3(
dtype=torch.float32, device=self.device
)
logits_shape = data["logits"].shape
costs, gradients = compute_with_pytorch_transducer(data=data)
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
for b in range(len(gradients)):
T = data["src_lengths"][b]
U = data["tgt_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
gradients[b, t, u],
ref_gradients[b, t, u],
atol=atol,
rtol=rtol,
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
)

def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
# Test from D21322854
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(dtype=np.float32)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
atol, rtol = 1e-6, 1e-2
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)


@common_utils.skipIfNoExtension
Expand Down

0 comments on commit 1aed0d2

Please sign in to comment.