Skip to content

Commit

Permalink
Add smoke test for sox_io fileobj (#1165)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 29, 2021
1 parent b152ee6 commit 5085aeb
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import itertools
import unittest

Expand Down Expand Up @@ -85,3 +86,70 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)


@skipIfNoExtension
class SmokeTestFileObj(TorchaudioTestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
duration = 1
num_frames = sample_rate * duration
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)

fileobj = io.BytesIO()
# 1. run save
sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
# 2. run info
fileobj.seek(0)
info = sox_io_backend.info(fileobj, format=ext)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
fileobj.seek(0)
loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext)
assert sr == sample_rate
assert loaded.shape[0] == num_channels

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)))
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)))
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)

0 comments on commit 5085aeb

Please sign in to comment.