Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Nov 3, 2021
1 parent 30d0e34 commit 8d8feab
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@ def _assert_consistency(self, func, tensor, shape_only=False):
output = output.shape
self.assertEqual(ts_output, output)

def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
def _assert_consistency_complex(self, func, tensor):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch_script(func)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)

torch.random.manual_seed(40)
output = func(tensor)

Expand Down Expand Up @@ -672,16 +669,14 @@ def func_beta(tensor):

def test_phase_vocoder(self):
def func(tensor):
is_complex = tensor.is_complex()

n_freq = tensor.size(-2 if is_complex else -3)
n_freq = tensor.size(-2)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
dtype=torch.real(tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)
Expand Down
6 changes: 3 additions & 3 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def test_batch_lfcc(self):
def test_batch_TimeStretch(self):
rate = 2
num_freq = 1025
num_frames = 400
batch = 3

spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
transform = T.TimeStretch(
fixed_rate=rate,
n_freq=num_freq,
n_freq=num_freq // 2 + 1,
hop_length=512
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ def _assert_consistency(self, transform, tensor, *args):
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)

def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args):
def _assert_consistency_complex(self, transform, tensor, *args):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch_script(transform)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)
Expand Down Expand Up @@ -127,14 +125,20 @@ def test_SpectralCentroid(self):
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)

def test_TimeStretch(self):
n_freq = 400
n_fft = 1025
n_freq = n_fft // 2 + 1
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat)
batch = 10
num_channels = 2

waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels)
tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
tensor = tensor.reshape(batch, num_channels, n_freq, -1)
self._assert_consistency_complex(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
False,
)

def test_PitchShift(self):
Expand All @@ -157,7 +161,7 @@ def test_PSD_with_mask(self):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device)
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask)
self._assert_consistency_complex(T.PSD(), spectrogram, mask)


class TransformsFloat32Only(TestBaseMixin):
Expand Down Expand Up @@ -193,5 +197,5 @@ def test_MVDR(self, solution, online):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(
T.MVDR(solution=solution, online=online),
spectrogram, False, mask_s, mask_n
spectrogram, mask_s, mask_n
)

0 comments on commit 8d8feab

Please sign in to comment.