Skip to content

Commit

Permalink
Merge pull request #123 from hbredin/target_support
Browse files (browse the repository at this point in the history)
BREAKING: add support for targets as discussed in #3
  • Loading branch information
iver56 authored Apr 1, 2022
2 parents: 71ba564 + d252497; commit 7bc37e5
Show file tree
Hide file tree
Showing 37 changed files with 1,352 additions and 373 deletions.
28 changes: 13 additions & 15 deletions tests/test_background_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TestAddBackgroundNoise(unittest.TestCase):
def setUp(self):
self.sample_rate = 16000
self.batch_size = 16
self.empty_input_audio = torch.empty(0)
self.empty_input_audio = torch.empty(0, 1, 16000)
# TODO: use utils.io.Audio
self.input_audio = (
torch.from_numpy(
Expand All @@ -47,15 +47,15 @@ def setUp(self):
def test_background_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.bg_noise_transform_no_guarantee(
self.input_audio, self.sample_rate
)
).samples
self.assertTrue(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))

def test_background_noise_no_guarantee_with_empty_tensor(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.bg_noise_transform_no_guarantee(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -69,7 +69,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.bg_noise_transform_guaranteed(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -81,7 +81,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self):
def test_background_noise_guaranteed_with_single_tensor(self):
mixed_input = self.bg_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -90,15 +90,15 @@ def test_background_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.bg_noise_transform_guaranteed(
self.input_audios, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_inputs, self.input_audios))
self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0))
self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1))

def test_background_short_noise_guaranteed_with_batched_tensor(self):
mixed_input = self.bg_short_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -108,7 +108,7 @@ def test_background_short_noise_guaranteed_with_batched_cuda_tensor(self):
input_audio_cuda = self.input_audio.cuda()
mixed_input = self.bg_short_noise_transform_guaranteed(
input_audio_cuda, self.sample_rate
)
).samples
assert not torch.equal(mixed_input, input_audio_cuda)
assert mixed_input.shape == input_audio_cuda.shape
assert mixed_input.dtype == input_audio_cuda.dtype
Expand All @@ -120,7 +120,7 @@ def test_varying_snr_within_batch(self):
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0
)
augmented_audios = augment(self.input_audios, self.sample_rate)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape))
self.assertFalse(torch.equal(augmented_audios, self.input_audios))
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_min_equals_max(self):
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=desired_snr, max_snr_in_db=desired_snr, p=1.0
)
augmented_audios = augment(self.input_audios, self.sample_rate)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape))
self.assertFalse(torch.equal(augmented_audios, self.input_audios))
Expand All @@ -171,11 +171,9 @@ def test_compatibility_of_resampled_length(self):
input_sample_rate = random.randint(1000, 5000)
bg_sample_rate = random.randint(1000, 5000)

noise = np.random.uniform(
low=-0.2,
high=0.2,
size=(bg_length,),
).astype(np.float32)
noise = np.random.uniform(low=-0.2, high=0.2, size=(bg_length,),).astype(
np.float32
)
tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
try:
os.makedirs(tmp_dir)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ def test_band_pass_filter(self):
for _ in range(20):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
2 changes: 1 addition & 1 deletion tests/test_band_stop_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def test_band_reject_filter(self):
augment = BandStopFilter(p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
12 changes: 6 additions & 6 deletions tests/test_colored_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.sample_rate = 16000
self.audio = Audio(sample_rate=self.sample_rate)
self.batch_size = 16
self.empty_input_audio = torch.empty(0)
self.empty_input_audio = torch.empty(0, 1, 16000)

self.input_audio = self.audio(
TEST_FIXTURES_DIR / "acoustic_guitar_0.wav"
Expand All @@ -26,15 +26,15 @@ def setUp(self):
def test_colored_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.cl_noise_transform_no_guarantee(
self.input_audio, self.sample_rate
)
).samples
self.assertTrue(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))

def test_background_noise_no_guarantee_with_empty_tensor(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_no_guarantee(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -48,7 +48,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_guaranteed(
self.empty_input_audio, self.sample_rate
)
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
Expand All @@ -60,7 +60,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self):
def test_colored_noise_guaranteed_with_single_tensor(self):
mixed_input = self.cl_noise_transform_guaranteed(
self.input_audio, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))
Expand All @@ -69,7 +69,7 @@ def test_colored_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.cl_noise_transform_guaranteed(
self.input_audios, self.sample_rate
)
).samples
self.assertFalse(torch.equal(mixed_inputs, self.input_audios))
self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0))
self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1))
Expand Down
24 changes: 9 additions & 15 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_compose(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = -convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -41,7 +41,7 @@ def test_compose_with_torchaudio_transform(self):
augment = Compose([Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0)])
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = -convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -64,7 +64,7 @@ def test_compose_with_p_zero(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
assert_array_equal(samples, processed_samples)

def test_freeze_and_unfreeze_parameters(self):
Expand All @@ -82,17 +82,17 @@ def test_freeze_and_unfreeze_parameters(self):

processed_samples1 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
augment.freeze_parameters()
processed_samples2 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
assert_array_equal(processed_samples1, processed_samples2)

augment.unfreeze_parameters()
processed_samples3 = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertNotEqual(processed_samples1[0, 0, 0], processed_samples3[0, 0, 0])

def test_shuffle(self):
Expand All @@ -112,7 +112,7 @@ def test_shuffle(self):
for i in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()

# Either PeakNormalization or Gain was applied last
if processed_samples[0, 0, 0] < 0.2:
Expand All @@ -126,14 +126,8 @@ def test_shuffle(self):
self.assertGreater(num_gain_last, 10)

def test_supported_modes_property(self):
augment = Compose(
transforms=[
PeakNormalization(p=1.0),
],
)
augment = Compose(transforms=[PeakNormalization(p=1.0),],)
assert augment.supported_modes == {"per_batch", "per_example", "per_channel"}

augment = Compose(
transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],
)
augment = Compose(transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],)
assert augment.supported_modes == {"per_example"}
2 changes: 1 addition & 1 deletion tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_transform_is_differentiable(augment):
optim = SGD([samples], lr=1.0)
for i in range(10):
optim.zero_grad()
transformed = augment(samples=samples, sample_rate=sample_rate)
transformed = augment(samples=samples, sample_rate=sample_rate).samples
# Compute mean absolute error
loss = torch.mean(torch.abs(samples - transformed))
loss.backward()
Expand Down
28 changes: 15 additions & 13 deletions tests/test_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_gain(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
expected_factor = convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
Expand All @@ -43,7 +43,7 @@ def test_gain_per_channel(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

num_unprocessed_channels = 0
num_processed_channels = 0
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_gain_per_channel_with_p_mode_per_batch(self):
for i in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

if np.allclose(processed_samples, samples_batch):
continue
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_gain_per_channel_with_p_mode_per_example(self):
)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=None
).numpy()
).samples.numpy()

num_unprocessed_examples = 0
num_processed_examples = 0
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_gain_per_batch(self):
for i in range(1000):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()

estimated_gain_factors = processed_samples / samples_batch
self.assertAlmostEqual(
Expand All @@ -354,7 +354,7 @@ def test_eval(self):

processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()

np.testing.assert_array_equal(samples, processed_samples)

Expand All @@ -366,7 +366,7 @@ def test_variability_within_batch(self):
augment = Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.5)
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

num_unprocessed_examples = 0
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_variability_within_batch_with_p_mode_per_batch(self):
for _ in range(100):
processed_samples = augment(
samples=torch.from_numpy(samples_batch), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

if np.allclose(processed_samples, samples_batch):
Expand Down Expand Up @@ -456,7 +456,9 @@ def test_reset_distribution(self):
# Change the parameters after init
augment.min_gain_in_db = -18
augment.max_gain_in_db = 3
processed_samples = augment(samples=torch.from_numpy(samples_batch)).numpy()
processed_samples = augment(
samples=torch.from_numpy(samples_batch)
).samples.numpy()
self.assertEqual(processed_samples.dtype, np.float32)

actual_gains_in_db = []
Expand Down Expand Up @@ -490,7 +492,7 @@ def test_cuda_reset_distribution(self):
augment(
samples=torch.from_numpy(samples_batch).cuda(), sample_rate=sample_rate
)
.cpu()
.samples.cpu()
.numpy()
)
self.assertEqual(processed_samples.dtype, np.float32)
Expand Down Expand Up @@ -538,7 +540,7 @@ def test_gain_to_device_cuda(self):
samples=torch.from_numpy(samples).to(device=cuda_device),
sample_rate=sample_rate,
)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand All @@ -558,7 +560,7 @@ def test_gain_cuda(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda()
processed_samples = (
augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand All @@ -578,7 +580,7 @@ def test_gain_cuda_cpu(self):
augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda().cpu()
processed_samples = (
augment(samples=torch.from_numpy(samples).cpu(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
expected_factor = convert_decibels_to_amplitude_ratio(-6)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_high_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_high_pass_filter(self):
augment = HighPassFilter(p=1.0)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).numpy()
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)

Expand All @@ -41,7 +41,7 @@ def test_high_pass_filter_cuda(self):
augment = HighPassFilter(p=1.0)
processed_samples = (
augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate)
.cpu()
.samples.cpu()
.numpy()
)
self.assertEqual(processed_samples.shape, samples.shape)
Expand Down
Loading

0 comments on commit 7bc37e5

Please sign in to comment.