Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAKING: add support for targets as discussed in #3 (#123)

Merged
15 commits were merged on Apr 1, 2022 (source and target branch names were lost in extraction)
Merged
28 changes: 13 additions & 15 deletions tests/test_background_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TestAddBackgroundNoise(unittest.TestCase):
def setUp(self):
self.sample_rate = 16000
self.batch_size = 16
self.empty_input_audio = torch.empty(0)
self.empty_input_audio = torch.empty(0, 1, 16000)
# TODO: use utils.io.Audio
self.input_audio = (
torch.from_numpy(
Expand All @@ -47,15 +47,15 @@ def setUp(self):
def test_background_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.bg_noise_transform_no_guarantee(
self.input_audio, self.sample_rate
)
).samples
self.assertTrue(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))

def test_background_noise_no_guarantee_with_empty_tensor(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.bg_noise_transform_no_guarantee(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -69,7 +69,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.bg_noise_transform_guaranteed(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -81,7 +81,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self):
def test_background_noise_guaranteed_with_single_tensor(self):
mixed_input = self.bg_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -90,15 +90,15 @@ def test_background_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.bg_noise_transform_guaranteed(
self.input_audios, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_inputs, self.input_audios))
self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0))
self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1))

def test_background_short_noise_guaranteed_with_batched_tensor(self):
mixed_input = self.bg_short_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -108,7 +108,7 @@ def test_background_short_noise_guaranteed_with_batched_cuda_tensor(self):
input_audio_cuda = self.input_audio.cuda()
mixed_input = self.bg_short_noise_transform_guaranteed(
input_audio_cuda, self.sample_rate
)
).samples
assert not torch.equal(mixed_input, input_audio_cuda)
assert mixed_input.shape == input_audio_cuda.shape
assert mixed_input.dtype == input_audio_cuda.dtype
Expand All @@ -120,7 +120,7 @@ def test_varying_snr_within_batch(self):
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0
)
augmented_audios = augment(self.input_audios, self.sample_rate)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape))
self.assertFalse(torch.equal(augmented_audios, self.input_audios))
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_min_equals_max(self):
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=desired_snr, max_snr_in_db=desired_snr, p=1.0
)
augmented_audios = augment(self.input_audios, self.sample_rate)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape))
self.assertFalse(torch.equal(augmented_audios, self.input_audios))
Expand All @@ -171,11 +171,9 @@ def test_compatibility_of_resampled_length(self):
input_sample_rate = random.randint(1000, 5000)
bg_sample_rate = random.randint(1000, 5000)

noise = np.random.uniform(
low=-0.2,
high=0.2,
size=(bg_length,),
).astype(np.float32)
noise = np.random.uniform(low=-0.2, high=0.2, size=(bg_length,),).astype(
np.float32
)
tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
try:
os.makedirs(tmp_dir)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ def test_band_pass_filter(self):
for _ in range(20):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
2 changes: 1 addition & 1 deletion tests/test_band_stop_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def test_band_reject_filter(self):
augment = BandStopFilter(p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
12 changes: 6 additions & 6 deletions tests/test_colored_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.sample_rate = 16000
self.audio = Audio(sample_rate=self.sample_rate)
self.batch_size = 16
self.empty_input_audio = torch.empty(0)
self.empty_input_audio = torch.empty(0, 1, 16000)

self.input_audio = self.audio(
TEST_FIXTURES_DIR / "acoustic_guitar_0.wav"
Expand All @@ -26,15 +26,15 @@ def setUp(self):
def test_colored_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.cl_noise_transform_no_guarantee(
self.input_audio, self.sample_rate
)
).samples
self.assertTrue(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))

def test_background_noise_no_guarantee_with_empty_tensor(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_no_guarantee(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -48,7 +48,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_guaranteed(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -60,7 +60,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self):
def test_colored_noise_guaranteed_with_single_tensor(self):
mixed_input = self.cl_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -69,7 +69,7 @@ def test_colored_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.cl_noise_transform_guaranteed(
self.input_audios, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_inputs, self.input_audios))
self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0))
self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1))
Expand Down
24 changes: 9 additions & 15 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_compose(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = -convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -41,7 +41,7 @@ def test_compose_with_torchaudio_transform(self):
augment = Compose([Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0)])
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = -convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -64,7 +64,7 @@ def test_compose_with_p_zero(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
assert_array_equal(samples, processed_samples)

def test_freeze_and_unfreeze_parameters(self):
Expand All @@ -82,17 +82,17 @@ def test_freeze_and_unfreeze_parameters(self):

processed_samples1 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
augment.freeze_parameters()
processed_samples2 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
assert_array_equal(processed_samples1, processed_samples2)

augment.unfreeze_parameters()
processed_samples3 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertNotEqual(processed_samples1[0, 0, 0], processed_samples3[0, 0, 0])

def test_shuffle(self):
Expand All @@ -112,7 +112,7 @@ def test_shuffle(self):
for i in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()

# Either PeakNormalization or Gain was applied last
if processed_samples[0, 0, 0] < 0.2:
Expand All @@ -126,14 +126,8 @@ def test_shuffle(self):
self.assertGreater(num_gain_last, 10)

def test_supported_modes_property(self):
augment = Compose(
transforms=[
PeakNormalization(p=1.0),
],
)
augment = Compose(transforms=[PeakNormalization(p=1.0),],)
assert augment.supported_modes == {"per_batch", "per_example", "per_channel"}

augment = Compose(
transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],
)
augment = Compose(transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],)
assert augment.supported_modes == {"per_example"}
2 changes: 1 addition & 1 deletion tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_transform_is_differentiable(augment):
optim = SGD([samples], lr=1.0)
for i in range(10):
optim.zero_grad()
transformed = augment(samples=samples, sample_rate=sample_rate)
transformed = augment(samples=samples, sample_rate=sample_rate).samples
# Compute mean absolute error
loss = torch.mean(torch.abs(samples - transformed))
loss.backward()
Expand Down
28 changes: 15 additions & 13 deletions tests/test_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_gain(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -43,7 +43,7 @@ def test_gain_per_channel(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

num_unprocessed_channels = 0
num_processed_channels = 0
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_gain_per_channel_with_p_mode_per_batch(self):
for i in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

if np.allclose(processed_samples, samples_batch):
continue
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_gain_per_channel_with_p_mode_per_example(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=None
).numpy()
).samples.numpy()

num_unprocessed_examples = 0
num_processed_examples = 0
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_gain_per_batch(self):
for i in range(1000):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

estimated_gain_factors = processed_samples / samples_batch
self.assertAlmostEqual(
Expand All @@ -354,7 +354,7 @@ def test_eval(self):

processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()

np.testing.assert_array_equal(samples, processed_samples)

Expand All @@ -366,7 +366,7 @@ def test_variability_within_batch(self):
augment = Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.5)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

num_unprocessed_examples = 0
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_variability_within_batch_with_p_mode_per_batch(self):
for _ in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

if np.allclose(processed_samples, samples_batch):
Expand Down Expand Up @@ -456,7 +456,9 @@ def test_reset_distribution(self):
# Change the parameters after init
augment.min_gain_in_db = -18
augment.max_gain_in_db = 3
processed_samples = augment(samples=torch.from_numpy(samples_batch)).numpy()
processed_samples = augment(
samples=torch.from_numpy(samples_batch)
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

actual_gains_in_db = []
Expand Down Expand Up @@ -490,7 +492,7 @@ def test_cuda_reset_distribution(self):
augment(
samples=torch.from_numpy(samples_batch).cuda(), sample_rate=sample_rate
)
.cpu()
.samples.cpu()
.numpy()
)
self.assertEqual(processed_samples.dtype, np.float32)
Expand Down Expand Up @@ -538,7 +540,7 @@ def test_gain_to_device_cuda(self):
samples=torch.from_numpy(samples).to(device=cuda_device),
sample_rate=sample_rate,
)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand All @@ -558,7 +560,7 @@ def test_gain_cuda(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda()
processed_samples = (
augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand All @@ -578,7 +580,7 @@ def test_gain_cuda_cpu(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda().cpu()
processed_samples = (
augment(samples=torch.from_numpy(samples).cpu(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_high_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_high_pass_filter(self):
augment = HighPassFilter(p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)

Expand All @@ -41,7 +41,7 @@ def test_high_pass_filter_cuda(self):
augment = HighPassFilter(p=1.0)
processed_samples = (
augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
self.assertEqual(processed_samples.shape, samples.shape)
Expand Down
Loading