Skip to content

Commit

Permalink
update tests and TimeStretch
Browse files Browse the repository at this point in the history
  • Loading branch information
anjali411 committed Jul 1, 2020
1 parent 2f6494b commit 6795dc4
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def test_batch_TimeStretch(self):
rate = 2

complex_specgrams = torch.stft(waveform, **kwargs)
complex_specgrams = torch.view_as_complex(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
Expand Down
2 changes: 1 addition & 1 deletion test/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
Returns:
Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
# assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

if overriding_rate is None:
rate = self.fixed_rate
Expand Down

0 comments on commit 6795dc4

Please sign in to comment.