diff --git a/tests/test_background_noise.py b/tests/test_background_noise.py index 0a87d50b..b10a904f 100644 --- a/tests/test_background_noise.py +++ b/tests/test_background_noise.py @@ -21,7 +21,7 @@ class TestAddBackgroundNoise(unittest.TestCase): def setUp(self): self.sample_rate = 16000 self.batch_size = 16 - self.empty_input_audio = torch.empty(0) + self.empty_input_audio = torch.empty(0, 1, 16000) # TODO: use utils.io.Audio self.input_audio = ( torch.from_numpy( @@ -47,7 +47,7 @@ def setUp(self): def test_background_noise_no_guarantee_with_single_tensor(self): mixed_input = self.bg_noise_transform_no_guarantee( self.input_audio, self.sample_rate - ) + ).samples self.assertTrue(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) @@ -55,7 +55,7 @@ def test_background_noise_no_guarantee_with_empty_tensor(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.bg_noise_transform_no_guarantee( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -69,7 +69,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.bg_noise_transform_guaranteed( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -81,7 +81,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self): def test_background_noise_guaranteed_with_single_tensor(self): mixed_input = self.bg_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -90,7 +90,7 @@ def test_background_noise_guaranteed_with_batched_tensor(self): random.seed(42) mixed_inputs = self.bg_noise_transform_guaranteed( self.input_audios, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_inputs, self.input_audios)) self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0)) self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1)) @@ -98,7 +98,7 @@ def test_background_noise_guaranteed_with_batched_tensor(self): def test_background_short_noise_guaranteed_with_batched_tensor(self): mixed_input = self.bg_short_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -108,7 +108,7 @@ def test_background_short_noise_guaranteed_with_batched_cuda_tensor(self): input_audio_cuda = self.input_audio.cuda() mixed_input = self.bg_short_noise_transform_guaranteed( input_audio_cuda, self.sample_rate - ) + ).samples assert not torch.equal(mixed_input, input_audio_cuda) assert mixed_input.shape == input_audio_cuda.shape assert mixed_input.dtype == input_audio_cuda.dtype @@ -120,7 +120,7 @@ def test_varying_snr_within_batch(self): augment = AddBackgroundNoise( self.bg_path, min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0 ) - augmented_audios = augment(self.input_audios, self.sample_rate) + augmented_audios = augment(self.input_audios, self.sample_rate).samples self.assertEqual(tuple(augmented_audios.shape), 
tuple(self.input_audios.shape)) self.assertFalse(torch.equal(augmented_audios, self.input_audios)) @@ -150,7 +150,7 @@ def test_min_equals_max(self): augment = AddBackgroundNoise( self.bg_path, min_snr_in_db=desired_snr, max_snr_in_db=desired_snr, p=1.0 ) - augmented_audios = augment(self.input_audios, self.sample_rate) + augmented_audios = augment(self.input_audios, self.sample_rate).samples self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape)) self.assertFalse(torch.equal(augmented_audios, self.input_audios)) @@ -171,11 +171,9 @@ def test_compatibility_of_resampled_length(self): input_sample_rate = random.randint(1000, 5000) bg_sample_rate = random.randint(1000, 5000) - noise = np.random.uniform( - low=-0.2, - high=0.2, - size=(bg_length,), - ).astype(np.float32) + noise = np.random.uniform(low=-0.2, high=0.2, size=(bg_length,),).astype( + np.float32 + ) tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) try: os.makedirs(tmp_dir) diff --git a/tests/test_band_pass_filter.py b/tests/test_band_pass_filter.py index 98698c19..03dfe59b 100644 --- a/tests/test_band_pass_filter.py +++ b/tests/test_band_pass_filter.py @@ -22,6 +22,6 @@ def test_band_pass_filter(self): for _ in range(20): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_band_stop_filter.py b/tests/test_band_stop_filter.py index 2a56716d..5eb32a99 100644 --- a/tests/test_band_stop_filter.py +++ b/tests/test_band_stop_filter.py @@ -21,6 +21,6 @@ def test_band_reject_filter(self): augment = BandStopFilter(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_colored_noise.py b/tests/test_colored_noise.py index cf8ab35c..4d7cfbf5 100644 --- a/tests/test_colored_noise.py +++ b/tests/test_colored_noise.py @@ -13,7 +13,7 @@ def setUp(self): self.sample_rate = 16000 self.audio = Audio(sample_rate=self.sample_rate) self.batch_size = 16 - self.empty_input_audio = torch.empty(0) + self.empty_input_audio = torch.empty(0, 1, 16000) self.input_audio = self.audio( TEST_FIXTURES_DIR / "acoustic_guitar_0.wav" @@ -26,7 +26,7 @@ def setUp(self): def test_colored_noise_no_guarantee_with_single_tensor(self): mixed_input = self.cl_noise_transform_no_guarantee( self.input_audio, self.sample_rate - ) + ).samples self.assertTrue(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) @@ -34,7 +34,7 @@ def test_background_noise_no_guarantee_with_empty_tensor(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.cl_noise_transform_no_guarantee( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -48,7 +48,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.cl_noise_transform_guaranteed( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -60,7 +60,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self): def 
test_colored_noise_guaranteed_with_single_tensor(self): mixed_input = self.cl_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -69,7 +69,7 @@ def test_colored_noise_guaranteed_with_batched_tensor(self): random.seed(42) mixed_inputs = self.cl_noise_transform_guaranteed( self.input_audios, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_inputs, self.input_audios)) self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0)) self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1)) diff --git a/tests/test_compose.py b/tests/test_compose.py index 0a2d3e86..3750978b 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -24,7 +24,7 @@ def test_compose(self): ) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = -convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -41,7 +41,7 @@ def test_compose_with_torchaudio_transform(self): augment = Compose([Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0)]) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = -convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -64,7 +64,7 @@ def test_compose_with_p_zero(self): ) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_array_equal(samples, processed_samples) def test_freeze_and_unfreeze_parameters(self): @@ -82,17 +82,17 @@ def test_freeze_and_unfreeze_parameters(self): processed_samples1 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() augment.freeze_parameters() processed_samples2 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_array_equal(processed_samples1, processed_samples2) augment.unfreeze_parameters() processed_samples3 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertNotEqual(processed_samples1[0, 0, 0], processed_samples3[0, 0, 0]) def test_shuffle(self): @@ -112,7 +112,7 @@ def test_shuffle(self): for i in range(100): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() # Either PeakNormalization or Gain was applied last if processed_samples[0, 0, 0] < 0.2: @@ -126,14 +126,8 @@ def test_shuffle(self): self.assertGreater(num_gain_last, 10) def test_supported_modes_property(self): - augment = Compose( - transforms=[ - PeakNormalization(p=1.0), - ], - ) + augment = Compose(transforms=[PeakNormalization(p=1.0),],) assert augment.supported_modes == {"per_batch", "per_example", "per_channel"} - augment = Compose( - transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)], - ) + augment = Compose(transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],) assert augment.supported_modes == {"per_example"} diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index d0b96eef..29c7ed09 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -65,7 +65,7 @@ def test_transform_is_differentiable(augment): optim = SGD([samples], lr=1.0) for i in 
range(10): optim.zero_grad() - transformed = augment(samples=samples, sample_rate=sample_rate) + transformed = augment(samples=samples, sample_rate=sample_rate).samples # Compute mean absolute error loss = torch.mean(torch.abs(samples - transformed)) loss.backward() diff --git a/tests/test_gain.py b/tests/test_gain.py index 6510a4a6..a0695a42 100644 --- a/tests/test_gain.py +++ b/tests/test_gain.py @@ -20,7 +20,7 @@ def test_gain(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -43,7 +43,7 @@ def test_gain_per_channel(self): ) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() num_unprocessed_channels = 0 num_processed_channels = 0 @@ -140,7 +140,7 @@ def test_gain_per_channel_with_p_mode_per_batch(self): for i in range(100): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() if np.allclose(processed_samples, samples_batch): continue @@ -235,7 +235,7 @@ def test_gain_per_channel_with_p_mode_per_example(self): ) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=None - ).numpy() + ).samples.numpy() num_unprocessed_examples = 0 num_processed_examples = 0 @@ -331,7 +331,7 @@ def test_gain_per_batch(self): for i in range(1000): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() estimated_gain_factors = processed_samples / samples_batch self.assertAlmostEqual( @@ -354,7 +354,7 @@ def test_eval(self): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() np.testing.assert_array_equal(samples, processed_samples) @@ -366,7 +366,7 @@ def test_variability_within_batch(self): augment = Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) num_unprocessed_examples = 0 @@ -406,7 +406,7 @@ def test_variability_within_batch_with_p_mode_per_batch(self): for _ in range(100): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) if np.allclose(processed_samples, samples_batch): @@ -456,7 +456,9 @@ def test_reset_distribution(self): # Change the parameters after init augment.min_gain_in_db = -18 augment.max_gain_in_db = 3 - processed_samples = augment(samples=torch.from_numpy(samples_batch)).numpy() + processed_samples = augment( + samples=torch.from_numpy(samples_batch) + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) actual_gains_in_db = [] @@ -490,7 +492,7 @@ def test_cuda_reset_distribution(self): augment( samples=torch.from_numpy(samples_batch).cuda(), sample_rate=sample_rate ) - .cpu() + .samples.cpu() .numpy() ) self.assertEqual(processed_samples.dtype, np.float32) @@ -538,7 +540,7 @@ def test_gain_to_device_cuda(self): samples=torch.from_numpy(samples).to(device=cuda_device), sample_rate=sample_rate, ) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) @@ -558,7 +560,7 @@ def 
test_gain_cuda(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda() processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) @@ -578,7 +580,7 @@ def test_gain_cuda_cpu(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda().cpu() processed_samples = ( augment(samples=torch.from_numpy(samples).cpu(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) diff --git a/tests/test_high_pass_filter.py b/tests/test_high_pass_filter.py index fc9d6fa3..d297c0c8 100644 --- a/tests/test_high_pass_filter.py +++ b/tests/test_high_pass_filter.py @@ -22,7 +22,7 @@ def test_high_pass_filter(self): augment = HighPassFilter(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) @@ -41,7 +41,7 @@ def test_high_pass_filter_cuda(self): augment = HighPassFilter(p=1.0) processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) self.assertEqual(processed_samples.shape, samples.shape) diff --git a/tests/test_impulse_response.py b/tests/test_impulse_response.py index a6c533d4..00ea21b8 100644 --- a/tests/test_impulse_response.py +++ b/tests/test_impulse_response.py @@ -49,14 +49,13 @@ def ir_transform_no_guarantee(ir_path, sample_rate): def test_impulse_response_guaranteed_with_single_tensor_input(ir_transform, input_audio): - mixed_input = ir_transform(input_audio) + mixed_input = ir_transform(input_audio).samples assert mixed_input.shape == input_audio.shape assert not torch.equal(mixed_input, input_audio) @pytest.mark.parametrize( - "compensate_for_propagation_delay", - [False, True], + "compensate_for_propagation_delay", [False, True], ) def test_impulse_response_guaranteed_with_batched_tensor_input( ir_path, sample_rate, input_audios, compensate_for_propagation_delay @@ -66,7 +65,7 @@ def test_impulse_response_guaranteed_with_batched_tensor_input( compensate_for_propagation_delay=compensate_for_propagation_delay, p=1.0, sample_rate=sample_rate, - )(input_audios) + )(input_audios).samples assert mixed_inputs.shape == input_audios.shape assert not torch.equal(mixed_inputs, input_audios) @@ -76,7 +75,7 @@ def test_impulse_response_guaranteed_with_batched_cuda_tensor_input( input_audios, ir_transform ): input_audio_cuda = input_audios.cuda() - mixed_inputs = ir_transform(input_audio_cuda) + mixed_inputs = ir_transform(input_audio_cuda).samples assert not torch.equal(mixed_inputs, input_audio_cuda) assert mixed_inputs.shape == input_audio_cuda.shape assert mixed_inputs.dtype == input_audio_cuda.dtype @@ -86,21 +85,21 @@ def test_impulse_response_guaranteed_with_batched_cuda_tensor_input( def test_impulse_response_no_guarantee_with_single_tensor_input( input_audio, ir_transform_no_guarantee ): - mixed_input = ir_transform_no_guarantee(input_audio) + mixed_input = ir_transform_no_guarantee(input_audio).samples assert mixed_input.shape == input_audio.shape def test_impulse_response_no_guarantee_with_batched_tensor_input( input_audios, ir_transform_no_guarantee ): - mixed_inputs = ir_transform_no_guarantee(input_audios) + mixed_inputs = ir_transform_no_guarantee(input_audios).samples assert mixed_inputs.shape == 
input_audios.shape def test_impulse_response_guaranteed_with_zero_length_samples(ir_transform): - empty_audio = torch.empty(0) + empty_audio = torch.empty(0, 1, 16000) with pytest.warns(UserWarning, match="An empty samples tensor was passed"): - mixed_inputs = ir_transform(empty_audio) + mixed_inputs = ir_transform(empty_audio).samples assert torch.equal(mixed_inputs, empty_audio) @@ -108,7 +107,7 @@ def test_impulse_response_guaranteed_with_zero_length_samples(ir_transform): def test_impulse_response_access_file_paths(ir_path, sample_rate, input_audios): augment = ApplyImpulseResponse(ir_path, p=1.0, sample_rate=sample_rate) - mixed_inputs = augment(samples=input_audios, sample_rate=sample_rate) + mixed_inputs = augment(samples=input_audios, sample_rate=sample_rate).samples assert mixed_inputs.shape == input_audios.shape diff --git a/tests/test_low_pass_filter.py b/tests/test_low_pass_filter.py index e65708f8..9bb5559e 100644 --- a/tests/test_low_pass_filter.py +++ b/tests/test_low_pass_filter.py @@ -21,6 +21,6 @@ def test_low_pass_filter(self): augment = LowPassFilter(min_cutoff_freq=200, max_cutoff_freq=7000, p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_mix.py b/tests/test_mix.py new file mode 100644 index 00000000..da95a7f3 --- /dev/null +++ b/tests/test_mix.py @@ -0,0 +1,97 @@ +import unittest + +import torch + +from torch_audiomentations import Mix +from torch_audiomentations.utils.dsp import calculate_rms +from torch_audiomentations.utils.file import load_audio +from .utils import TEST_FIXTURES_DIR + + +class TestMix(unittest.TestCase): + def setUp(self): + self.sample_rate = 16000 + self.guitar = ( + torch.from_numpy( + load_audio( + TEST_FIXTURES_DIR / "acoustic_guitar_0.wav", + sample_rate=self.sample_rate, + ) + ) + .unsqueeze(0) + .unsqueeze(0) + ) + self.noise = ( + torch.from_numpy( + load_audio( + TEST_FIXTURES_DIR / "bg" / "bg.wav", sample_rate=self.sample_rate, + ) + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + common_num_samples = min(self.guitar.shape[-1], self.noise.shape[-1]) + self.guitar = self.guitar[:, :, :common_num_samples] + + self.guitar_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.guitar_target[:, :, :, 0] = 1 + + self.noise = self.noise[:, :, :common_num_samples] + self.noise_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.noise_target[:, :, :, 1] = 1 + + self.input_audios = torch.cat([self.guitar, self.noise], dim=0) + self.input_targets = torch.cat([self.guitar_target, self.noise_target], dim=0) + + def test_varying_snr_within_batch(self): + min_snr_in_db = 3 + max_snr_in_db = 30 + augment = Mix(min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0) + mixed_audios = augment(self.input_audios, self.sample_rate).samples + + self.assertEqual(tuple(mixed_audios.shape), tuple(self.input_audios.shape)) + self.assertFalse(torch.equal(mixed_audios, self.input_audios)) + + added_audios = mixed_audios - self.input_audios + + for i in range(len(self.input_audios)): + signal_rms = calculate_rms(self.input_audios[i]) + added_rms = calculate_rms(added_audios[i]) + snr_in_db = 20 * torch.log10(signal_rms / added_rms).item() + self.assertGreaterEqual(snr_in_db, min_snr_in_db) + self.assertLessEqual(snr_in_db, max_snr_in_db) + + def test_targets_union(self): + augment = 
Mix(p=1.0, mix_target="union") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + # check guitar target is still active in first (guitar) sample + self.assertTrue( + torch.equal(mixed_targets[0, :, :, 0], self.input_targets[0, :, :, 0]) + ) + # check noise target is still active in second (noise) sample + self.assertTrue( + torch.equal(mixed_targets[1, :, :, 1], self.input_targets[1, :, :, 1]) + ) + + def test_targets_original(self): + augment = Mix(p=1.0, mix_target="original") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + self.assertTrue(torch.equal(mixed_targets, self.input_targets)) + diff --git a/tests/test_peak_normalization.py b/tests/test_peak_normalization.py index 7fb3d1fc..278bb63e 100644 --- a/tests/test_peak_normalization.py +++ b/tests/test_peak_normalization.py @@ -23,7 +23,7 @@ def test_apply_to_all(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -52,7 +52,7 @@ def test_apply_to_only_too_loud_sounds(self): augment = PeakNormalization(apply_to="only_too_loud_sounds", p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -78,7 +78,7 @@ def test_digital_silence_in_batch(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -102,7 +102,7 @@ def test_only_digital_silence(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -127,7 +127,7 @@ def test_never_apply(self): augment = PeakNormalization(p=0.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_equal( processed_samples, @@ -157,7 +157,7 @@ def test_apply_to_all_cuda(self): augment = PeakNormalization(p=1.0) processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) @@ -182,7 +182,7 @@ def test_variability_within_batch(self): augment = PeakNormalization(p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) num_unprocessed_examples = 0 @@ -209,11 +209,13 @@ def test_freeze_parameters(self): sample_rate = 16000 augment = PeakNormalization(p=1.0) - _ = augment(samples=torch.from_numpy(samples1), sample_rate=sample_rate).numpy() + _ = augment( + samples=torch.from_numpy(samples1), sample_rate=sample_rate + ).samples.numpy() augment.freeze_parameters() processed_samples2 = augment( samples=torch.from_numpy(samples2), sample_rate=sample_rate - ).numpy() + ).samples.numpy() augment.unfreeze_parameters() assert_almost_equal( @@ -242,7 +244,7 @@ def test_stereo_sound(self): augment = PeakNormalization(apply_to="all", p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( 
processed_samples, diff --git a/tests/test_pitch_shift.py b/tests/test_pitch_shift.py index 0e9f4080..7d5e384a 100644 --- a/tests/test_pitch_shift.py +++ b/tests/test_pitch_shift.py @@ -21,30 +21,30 @@ def get_example(): class TestPitchShift(unittest.TestCase): def test_per_example_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_example") + aug = PitchShift(sample_rate=16000, p=1, mode="per_example") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, samples.shape) def test_per_channel_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_channel") + aug = PitchShift(sample_rate=16000, p=1, mode="per_channel") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, samples.shape) def test_per_batch_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_batch") + aug = PitchShift(sample_rate=16000, p=1, mode="per_batch") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, samples.shape) def error_raised(self): error = False try: PitchShift( - 16000, + sample_rate=16000, p=1, mode="per_example", min_transpose_semitones=0.0, diff --git a/tests/test_polarity_inversion.py b/tests/test_polarity_inversion.py index f5e7ed32..59cb5268 100644 --- a/tests/test_polarity_inversion.py +++ b/tests/test_polarity_inversion.py @@ -16,7 +16,7 @@ def test_polarity_inversion(self): augment = PolarityInversion(p=1.0) inverted_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( inverted_samples, np.array([[[-1.0, -0.5, 0.25, 0.125, 0.0]]], dtype=np.float32), @@ -30,7 +30,7 @@ def test_polarity_inversion_zero_probability(self): augment = PolarityInversion(p=0.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), @@ -45,7 +45,7 @@ def test_polarity_inversion_variability_within_batch(self): augment = PolarityInversion(p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() num_unprocessed_examples = 0 num_processed_examples = 0 @@ -71,7 +71,7 @@ def test_polarity_inversion_multichannel(self): augment = PolarityInversion(p=1.0) inverted_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( inverted_samples, np.array( @@ -89,7 +89,7 @@ def test_polarity_inversion_cuda(self): augment = PolarityInversion(p=1.0).cuda() inverted_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) assert_almost_equal( diff --git a/tests/test_shift.py b/tests/test_shift.py index 9ba1f012..148c0fce 100644 --- a/tests/test_shift.py +++ b/tests/test_shift.py @@ -27,7 +27,7 @@ def test_shift_by_1_sample_3dim(self, device_name): samples[1] += 1 augment = Shift(min_shift=1, max_shift=1, shift_unit="samples", p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_almost_equal( processed_samples.cpu(), @@ -42,7 +42,7 @@ def 
test_shift_by_1_sample_without_rollover(self): min_shift=1, max_shift=1, shift_unit="samples", rollover=False, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[0, 0, 1, 2], [0, 0, 1, 2]], [[0, 1, 2, 3], [0, 1, 2, 3]]], @@ -56,7 +56,7 @@ def test_negative_shift_by_2_samples(self): min_shift=-2, max_shift=-2, shift_unit="samples", rollover=True, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], @@ -70,7 +70,7 @@ def test_shift_by_fraction(self): min_shift=0.5, max_shift=0.5, shift_unit="fraction", rollover=True, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], @@ -83,7 +83,7 @@ def test_shift_by_seconds(self): augment = Shift( min_shift=-2, max_shift=-2, shift_unit="seconds", p=1.0, sample_rate=1 ) - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_almost_equal( processed_samples, @@ -105,7 +105,7 @@ def test_shift_by_seconds_specify_sample_rate_in_both_init_and_forward(self): rollover=False, ) # If sample_rate is specified in both __init__ and forward, then the latter will be used - processed_samples = augment(samples, sample_rate=forward_sample_rate) + processed_samples = augment(samples, sample_rate=forward_sample_rate).samples assert_almost_equal( processed_samples, [[[0, 0, 0, 1], [0, 0, 0, 1]], [[0, 0, 1, 2], [0, 0, 1, 2]]], @@ -127,7 +127,7 @@ def test_variability_within_batch(self): samples = torch.arange(4)[None, None].repeat(1000, 2, 1) augment = Shift(min_shift=-1, max_shift=1, shift_unit="samples", p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples applied_shift_counts = {-1: 0, 0: 0, 1: 0} for i in range(samples.shape[0]): diff --git a/tests/test_shuffle_channels.py b/tests/test_shuffle_channels.py index 052fd3c5..e264dafb 100644 --- a/tests/test_shuffle_channels.py +++ b/tests/test_shuffle_channels.py @@ -10,15 +10,12 @@ class TestShuffleChannels: def test_shuffle_mono(self): samples = torch.from_numpy( - np.array( - [[[1.0, -1.0, 1.0, -1.0, 1.0]]], - dtype=np.float32, - ) + np.array([[[1.0, -1.0, 1.0, -1.0, 1.0]]], dtype=np.float32,) ) augment = ShuffleChannels(p=1.0) with pytest.warns(UserWarning): - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_array_equal(samples.numpy(), processed_samples.numpy()) @@ -39,14 +36,13 @@ def test_variability_within_batch(self, device_name): torch.manual_seed(42) samples = np.array( - [[1.0, -1.0, 1.0, -1.0, 1.0], [0.1, -0.1, 0.1, -0.1, 0.1]], - dtype=np.float32, + [[1.0, -1.0, 1.0, -1.0, 1.0], [0.1, -0.1, 0.1, -0.1, 0.1]], dtype=np.float32, ) samples = np.stack([samples] * 1000, axis=0) samples = torch.from_numpy(samples).to(device) augment = ShuffleChannels(p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples orders = {"original": 0, "swapped": 0} for i in range(processed_samples.shape[0]): diff --git a/tests/test_someof.py b/tests/test_someof.py index be6ac0d7..a36c4007 100644 --- a/tests/test_someof.py +++ b/tests/test_someof.py @@ -23,21 +23,27 @@ def test_someof(self): augment = SomeOf(2, self.transforms) 
self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied def test_someof_with_p_zero(self): augment = SomeOf(2, self.transforms, p=0.0) self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertEqual(len(augment.transform_indexes), 0) # 0 transforms applied def test_someof_tuple(self): augment = SomeOf((1, None), self.transforms) self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertTrue( len(augment.transform_indexes) > 0 ) # at least one transform applied @@ -51,7 +57,7 @@ def test_someof_freeze_and_unfreeze_parameters(self): self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet processed_samples1 = augment( samples=samples, sample_rate=self.sample_rate - ).numpy() + ).samples.numpy() transform_indexes1 = augment.transform_indexes self.assertEqual(len(augment.transform_indexes), 2) @@ -59,7 +65,7 @@ def test_someof_freeze_and_unfreeze_parameters(self): processed_samples2 = augment( samples=samples, sample_rate=self.sample_rate - ).numpy() + ).samples.numpy() transform_indexes2 = augment.transform_indexes assert_array_equal(processed_samples1, processed_samples2) assert_array_equal(transform_indexes1, transform_indexes2) diff --git a/tests/test_time_inversion.py b/tests/test_time_inversion.py index b65865b4..ca1bc505 100644 --- a/tests/test_time_inversion.py +++ b/tests/test_time_inversion.py @@ -12,7 +12,7 @@ def setUp(self): def test_single_channel(self): samples = self.samples.unsqueeze(0).unsqueeze(0) # (B, C, T): (1, 1, 100) - processed_samples = self.augment(samples=samples, sample_rate=16000) + processed_samples = self.augment(samples=samples, sample_rate=16000).samples self.assertEqual(processed_samples.shape, samples.shape) self.assertTrue( @@ -25,7 +25,7 @@ def test_multi_channel(self): samples = torch.stack([self.samples, self.samples], dim=0).unsqueeze( 0 ) # (B, C, T): (1, 2, 100) - processed_samples = self.augment(samples=samples, sample_rate=16000) + processed_samples = self.augment(samples=samples, sample_rate=16000).samples self.assertEqual(processed_samples.shape, samples.shape) self.assertTrue( diff --git a/torch_audiomentations/__init__.py b/torch_audiomentations/__init__.py index aab77651..75c71c1f 100644 --- a/torch_audiomentations/__init__.py +++ b/torch_audiomentations/__init__.py @@ -6,6 +6,7 @@ from .augmentations.high_pass_filter import HighPassFilter from .augmentations.impulse_response import ApplyImpulseResponse from .augmentations.low_pass_filter import LowPassFilter +from .augmentations.mix import Mix from .augmentations.peak_normalization import PeakNormalization from .augmentations.pitch_shift import PitchShift from .augmentations.polarity_inversion import PolarityInversion diff --git a/torch_audiomentations/augmentations/background_noise.py b/torch_audiomentations/augmentations/background_noise.py index 2e391995..ef5d9ad5 100644 --- 
a/torch_audiomentations/augmentations/background_noise.py +++ b/torch_audiomentations/augmentations/background_noise.py @@ -1,13 +1,15 @@ import random from pathlib import Path -from typing import Union, List +from typing import Union, List, Optional import torch +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform, EmptyPathException from ..utils.dsp import calculate_rms from ..utils.file import find_audio_files from ..utils.io import Audio +from ..utils.object_dict import ObjectDict class AddBackgroundNoise(BaseWaveformTransform): @@ -15,11 +17,16 @@ class AddBackgroundNoise(BaseWaveformTransform): Add background noise to the input audio. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + # Note: This transform has only partial support for multichannel audio. Noises that are not # mono get mixed down to mono before they are added to all channels in the input. supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, background_paths: Union[List[Path], List[str], Path, str], @@ -29,6 +36,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ @@ -42,7 +50,13 @@ def __init__( :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if isinstance(background_paths, (list, tuple, set)): # TODO: check that one can read audio files @@ -94,14 +108,18 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens ) def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # (batch_size, num_samples) RMS-normalized background noise audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) @@ -115,19 +133,15 @@ def randomize_parameters( size=(batch_size,), fill_value=self.min_snr_in_db, dtype=torch.float32, - device=selected_samples.device, + device=samples.device, ) else: snr_distribution = torch.distributions.Uniform( low=torch.tensor( - self.min_snr_in_db, - dtype=torch.float32, - device=selected_samples.device, + self.min_snr_in_db, dtype=torch.float32, device=samples.device, ), high=torch.tensor( - self.max_snr_in_db, - dtype=torch.float32, - device=selected_samples.device, + self.max_snr_in_db, dtype=torch.float32, device=samples.device, ), validate_args=True, ) @@ -135,17 +149,28 @@ def randomize_parameters( sample_shape=(batch_size,) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - batch_size, num_channels, num_samples = selected_samples.shape + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + batch_size, num_channels, num_samples = samples.shape # (batch_size, num_samples) - background = self.transform_parameters["background"].to(selected_samples.device) + background = self.transform_parameters["background"].to(samples.device) # (batch_size, num_channels) - 
background_rms = calculate_rms(selected_samples) / ( + background_rms = calculate_rms(samples) / ( 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return selected_samples + background_rms.unsqueeze(-1) * background.view( - batch_size, 1, num_samples - ).expand(-1, num_channels, -1) + return ObjectDict( + samples=samples + + background_rms.unsqueeze(-1) + * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/band_pass_filter.py b/torch_audiomentations/augmentations/band_pass_filter.py index 8291f15a..8c58623f 100644 --- a/torch_audiomentations/augmentations/band_pass_filter.py +++ b/torch_audiomentations/augmentations/band_pass_filter.py @@ -1,8 +1,10 @@ import julius import torch - +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies +from ..utils.object_dict import ObjectDict class BandPassFilter(BaseWaveformTransform): @@ -10,9 +12,14 @@ class BandPassFilter(BaseWaveformTransform): Apply band-pass filtering to the input audio. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_center_frequency=200, @@ -23,6 +30,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_center_frequency: Minimum center frequency in hertz @@ -36,7 +44,13 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_center_frequency = min_center_frequency self.max_center_frequency = max_center_frequency @@ -65,29 +79,26 @@ def __init__( ) def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + + batch_size, _, num_samples = samples.shape # Sample frequencies uniformly in mel space, then convert back to frequency def get_dist(min_freq, max_freq): dist = torch.distributions.Uniform( low=convert_frequencies_to_mels( - torch.tensor( - min_freq, - dtype=torch.float32, - device=selected_samples.device, - ) + torch.tensor(min_freq, dtype=torch.float32, device=samples.device,) ), high=convert_frequencies_to_mels( - torch.tensor( - max_freq, - dtype=torch.float32, - device=selected_samples.device, - ) + torch.tensor(max_freq, dtype=torch.float32, device=samples.device,) ), validate_args=True, ) @@ -99,18 +110,20 @@ def get_dist(min_freq, max_freq): ) bandwidth_dist = torch.distributions.Uniform( - low=self.min_bandwidth_fraction, - high=self.max_bandwidth_fraction, + low=self.min_bandwidth_fraction, high=self.max_bandwidth_fraction, ) self.transform_parameters["bandwidth"] = bandwidth_dist.sample( sample_shape=(batch_size,) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - batch_size, num_channels, num_samples = 
selected_samples.shape - - if sample_rate is None: - sample_rate = self.sample_rate + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + batch_size, num_channels, num_samples = samples.shape low_cutoffs_as_fraction_of_sample_rate = ( self.transform_parameters["center_freq"] @@ -124,10 +137,15 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non ) # TODO: Instead of using a for loop, perform batched compute to speed things up for i in range(batch_size): - selected_samples[i] = julius.bandpass_filter( - selected_samples[i], + samples[i] = julius.bandpass_filter( + samples[i], cutoff_low=low_cutoffs_as_fraction_of_sample_rate[i].item(), cutoff_high=high_cutoffs_as_fraction_of_sample_rate[i].item(), ) - return selected_samples + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/band_stop_filter.py b/torch_audiomentations/augmentations/band_stop_filter.py index 21c33eea..272d9f1e 100644 --- a/torch_audiomentations/augmentations/band_stop_filter.py +++ b/torch_audiomentations/augmentations/band_stop_filter.py @@ -1,6 +1,8 @@ -import torch +from torch import Tensor +from typing import Optional from ..augmentations.band_pass_filter import BandPassFilter +from ..utils.object_dict import ObjectDict class BandStopFilter(BandPassFilter): @@ -9,9 +11,6 @@ class BandStopFilter(BandPassFilter): band reject filter and frequency mask. """ - supports_multichannel = True - requires_sample_rate = True - def __init__( self, min_center_frequency=200, @@ -22,6 +21,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_center_frequency: Minimum center frequency in hertz @@ -34,6 +34,7 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ super().__init__( @@ -41,14 +42,27 @@ def __init__( max_center_frequency, min_bandwidth_fraction, max_bandwidth_fraction, - mode, - p, - p_mode, - sample_rate, + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - band_pass_filtered_samples = super().apply_transform( - selected_samples.clone(), sample_rate + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + perturbed = super().apply_transform( + samples.clone(), + sample_rate, + targets=targets.clone() if targets is not None else None, + target_rate=target_rate, ) - return selected_samples - band_pass_filtered_samples + + perturbed.samples = samples - perturbed.samples + return perturbed diff --git a/torch_audiomentations/augmentations/colored_noise.py b/torch_audiomentations/augmentations/colored_noise.py index 0c285b6d..88066705 100644 --- a/torch_audiomentations/augmentations/colored_noise.py +++ b/torch_audiomentations/augmentations/colored_noise.py @@ -1,10 +1,13 @@ import torch +from torch import Tensor +from typing import Optional from math import ceil from torch_audiomentations.utils.fft import rfft, irfft from ..core.transforms_interface import BaseWaveformTransform from ..utils.dsp import calculate_rms from ..utils.io import Audio +from ..utils.object_dict import ObjectDict def 
_gen_noise(f_decay, num_samples, sample_rate, device): @@ -33,9 +36,14 @@ class AddColoredNoise(BaseWaveformTransform): Add colored noises to the input audio. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_snr_in_db: float = 3.0, @@ -46,6 +54,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_snr_in_db: minimum SNR in dB. @@ -61,9 +70,16 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_snr_in_db = min_snr_in_db self.max_snr_in_db = max_snr_in_db @@ -76,12 +92,16 @@ def __init__( raise ValueError("min_f_decay must not be greater than max_f_decay") def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ :params selected_samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # (batch_size, ) SNRs for param, mini, maxi in [ @@ -89,21 +109,21 @@ def randomize_parameters( ("f_decay", self.min_f_decay, self.max_f_decay), ]: dist = torch.distributions.Uniform( - low=torch.tensor( - mini, dtype=torch.float32, device=selected_samples.device - ), - high=torch.tensor( - maxi, dtype=torch.float32, device=selected_samples.device - ), + low=torch.tensor(mini, dtype=torch.float32, device=samples.device), + high=torch.tensor(maxi, dtype=torch.float32, device=samples.device), validate_args=True, ) self.transform_parameters[param] = dist.sample(sample_shape=(batch_size,)) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - batch_size, num_channels, num_samples = selected_samples.shape + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - if sample_rate is None: - sample_rate = self.sample_rate + batch_size, num_channels, num_samples = samples.shape # (batch_size, num_samples) noise = torch.stack( @@ -112,17 +132,22 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non self.transform_parameters["f_decay"][i], num_samples, sample_rate, - selected_samples.device, + samples.device, ) for i in range(batch_size) ] ) # (batch_size, num_channels) - noise_rms = calculate_rms(selected_samples) / ( + noise_rms = calculate_rms(samples) / ( 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return selected_samples + noise_rms.unsqueeze(-1) * noise.view( - batch_size, 1, num_samples - ).expand(-1, num_channels, -1) + return ObjectDict( + samples=samples + + noise_rms.unsqueeze(-1) + * noise.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/gain.py b/torch_audiomentations/augmentations/gain.py index 278aacc2..f647987d 100644 --- a/torch_audiomentations/augmentations/gain.py +++ b/torch_audiomentations/augmentations/gain.py @@ -1,8 +1,10 @@ import torch 
-import typing +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform from ..utils.dsp import convert_decibels_to_amplitude_ratio +from ..utils.object_dict import ObjectDict class Gain(BaseWaveformTransform): @@ -15,36 +17,53 @@ class Gain(BaseWaveformTransform): See also https://en.wikipedia.org/wiki/Clipping_(audio)#Digital_clipping """ + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, min_gain_in_db: float = -18.0, max_gain_in_db: float = 6.0, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_gain_in_db = min_gain_in_db self.max_gain_in_db = max_gain_in_db if self.min_gain_in_db >= self.max_gain_in_db: raise ValueError("max_gain_in_db must be higher than min_gain_in_db") def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): distribution = torch.distributions.Uniform( low=torch.tensor( - self.min_gain_in_db, dtype=torch.float32, device=selected_samples.device + self.min_gain_in_db, dtype=torch.float32, device=samples.device ), high=torch.tensor( - self.max_gain_in_db, dtype=torch.float32, device=selected_samples.device + self.max_gain_in_db, dtype=torch.float32, device=samples.device ), validate_args=True, ) - selected_batch_size = selected_samples.size(0) + selected_batch_size = samples.size(0) self.transform_parameters["gain_factors"] = ( convert_decibels_to_amplitude_ratio( distribution.sample(sample_shape=(selected_batch_size,)) @@ -53,5 +72,18 @@ def randomize_parameters( .unsqueeze(1) ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - return selected_samples * self.transform_parameters["gain_factors"] + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + return ObjectDict( + samples=samples * self.transform_parameters["gain_factors"], + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + diff --git a/torch_audiomentations/augmentations/high_pass_filter.py b/torch_audiomentations/augmentations/high_pass_filter.py index 07f023fb..9975318f 100644 --- a/torch_audiomentations/augmentations/high_pass_filter.py +++ b/torch_audiomentations/augmentations/high_pass_filter.py @@ -1,6 +1,8 @@ -import torch +from torch import Tensor +from typing import Optional from ..augmentations.low_pass_filter import LowPassFilter +from ..utils.object_dict import ObjectDict class HighPassFilter(LowPassFilter): @@ -8,9 +10,6 @@ class HighPassFilter(LowPassFilter): Apply high-pass filtering to the input audio. 
""" - supports_multichannel = True - requires_sample_rate = True - def __init__( self, min_cutoff_freq=20, @@ -19,6 +18,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_cutoff_freq: Minimum cutoff frequency in hertz @@ -27,11 +27,33 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(min_cutoff_freq, max_cutoff_freq, mode, p, p_mode, sample_rate) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - low_pass_filtered_samples = super().apply_transform( - selected_samples.clone(), sample_rate + super().__init__( + min_cutoff_freq, + max_cutoff_freq, + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) + + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + perturbed = super().apply_transform( + samples=samples.clone(), + sample_rate=sample_rate, + targets=targets.clone() if targets is not None else None, + target_rate=target_rate, ) - return selected_samples - low_pass_filtered_samples + + perturbed.samples = samples - perturbed.samples + return perturbed diff --git a/torch_audiomentations/augmentations/impulse_response.py b/torch_audiomentations/augmentations/impulse_response.py index e0f5d4e5..db60f38d 100644 --- a/torch_audiomentations/augmentations/impulse_response.py +++ b/torch_audiomentations/augmentations/impulse_response.py @@ -1,6 +1,7 @@ import random from pathlib import Path -from typing import Union, List +from typing import Union, List, Optional +from torch import Tensor import torch from torch.nn.utils.rnn import pad_sequence @@ -9,6 +10,7 @@ from ..utils.convolution import convolve from ..utils.file import find_audio_files from ..utils.io import Audio +from ..utils.object_dict import ObjectDict class ApplyImpulseResponse(BaseWaveformTransform): @@ -16,11 +18,16 @@ class ApplyImpulseResponse(BaseWaveformTransform): Convolve the given audio with impulse responses. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + # Note: This transform has only partial support for multichannel audio. IRs that are not # mono get mixed down to mono before they are convolved with all channels in the input. supports_multichannel = True requires_sample_rate = True + supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) + requires_target = False + def __init__( self, ir_paths: Union[List[Path], List[str], Path, str], @@ -30,6 +37,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param ir_paths: Either a path to a folder with audio files or a list of paths to audio files. 
@@ -43,8 +51,16 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if isinstance(ir_paths, (list, tuple, set)): # TODO: check that one can read audio files @@ -61,9 +77,15 @@ def __init__( self.convolve_mode = convolve_mode self.compensate_for_propagation_delay = compensate_for_propagation_delay - def randomize_parameters(self, selected_samples, sample_rate: int = None): + def randomize_parameters( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ): - batch_size, _, _ = selected_samples.shape + batch_size, _, _ = samples.shape audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) @@ -77,20 +99,25 @@ def randomize_parameters(self, selected_samples, sample_rate: int = None): self.transform_parameters["ir_paths"] = random_ir_paths - def apply_transform(self, selected_samples, sample_rate: int = None): + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape # (batch_size, 1, max_ir_length) - ir = self.transform_parameters["ir"].to(selected_samples.device) + ir = self.transform_parameters["ir"].to(samples.device) convolved_samples = convolve( - selected_samples, ir.expand(-1, num_channels, -1), mode=self.convolve_mode + samples, ir.expand(-1, num_channels, -1), mode=self.convolve_mode ) if self.compensate_for_propagation_delay: propagation_delays = ir.abs().argmax(dim=2, keepdim=False)[:, 0] - convolved_samples = torch.stack( [ convolved_sample[ @@ -103,7 +130,18 @@ def apply_transform(self, selected_samples, sample_rate: int = None): dim=0, ) - return convolved_samples + return ObjectDict( + samples=convolved_samples, + sample_rate=sample_rate, + targets=targets, # FIXME compensate targets as well? + target_rate=target_rate, + ) else: - return convolved_samples[..., :num_samples] + return ObjectDict( + samples=convolved_samples[..., :num_samples], + sample_rate=sample_rate, + targets=targets, # FIXME crop targets as well? + target_rate=target_rate, + ) + diff --git a/torch_audiomentations/augmentations/low_pass_filter.py b/torch_audiomentations/augmentations/low_pass_filter.py index 1f6812e2..4382cc49 100644 --- a/torch_audiomentations/augmentations/low_pass_filter.py +++ b/torch_audiomentations/augmentations/low_pass_filter.py @@ -1,8 +1,12 @@ import julius import torch +from torch import Tensor +from typing import Optional + from ..core.transforms_interface import BaseWaveformTransform from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies +from ..utils.object_dict import ObjectDict class LowPassFilter(BaseWaveformTransform): @@ -10,9 +14,14 @@ class LowPassFilter(BaseWaveformTransform): Apply low-pass filtering to the input audio. 
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_cutoff_freq=150, @@ -21,6 +30,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_cutoff_freq: Minimum cutoff frequency in hertz @@ -30,7 +40,13 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_cutoff_freq = min_cutoff_freq self.max_cutoff_freq = max_cutoff_freq @@ -38,27 +54,27 @@ def __init__( raise ValueError("min_cutoff_freq must not be greater than max_cutoff_freq") def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # Sample frequencies uniformly in mel space, then convert back to frequency dist = torch.distributions.Uniform( low=convert_frequencies_to_mels( torch.tensor( - self.min_cutoff_freq, - dtype=torch.float32, - device=selected_samples.device, + self.min_cutoff_freq, dtype=torch.float32, device=samples.device, ) ), high=convert_frequencies_to_mels( torch.tensor( - self.max_cutoff_freq, - dtype=torch.float32, - device=selected_samples.device, + self.max_cutoff_freq, dtype=torch.float32, device=samples.device, ) ), validate_args=True, @@ -67,19 +83,28 @@ def randomize_parameters( dist.sample(sample_shape=(batch_size,)) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - batch_size, num_channels, num_samples = selected_samples.shape + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - if sample_rate is None: - sample_rate = self.sample_rate + batch_size, num_channels, num_samples = samples.shape cutoffs_as_fraction_of_sample_rate = ( self.transform_parameters["cutoff_freq"] / sample_rate ) # TODO: Instead of using a for loop, perform batched compute to speed things up for i in range(batch_size): - selected_samples[i] = julius.lowpass_filter( - selected_samples[i], cutoffs_as_fraction_of_sample_rate[i].item() + samples[i] = julius.lowpass_filter( + samples[i], cutoffs_as_fraction_of_sample_rate[i].item() ) - return selected_samples + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/mix.py b/torch_audiomentations/augmentations/mix.py new file mode 100644 index 00000000..f73678de --- /dev/null +++ b/torch_audiomentations/augmentations/mix.py @@ -0,0 +1,124 @@ +from typing import Optional +import torch +from torch import Tensor + +from ..core.transforms_interface import BaseWaveformTransform +from ..utils.dsp import calculate_rms +from ..utils.io import Audio +from ..utils.object_dict import ObjectDict + + +class Mix(BaseWaveformTransform): + """ + Create a new sample by mixing it with another random sample from the same batch + + 
Signal-to-noise ratio (where "noise" is the second random sample) is selected
+    randomly between `min_snr_in_db` and `max_snr_in_db`.
+
+    `mix_target` controls how resulting targets are generated. It can be one of
+    "original" (targets are those of the original sample) or "union" (targets are the
+    union of original and overlapping targets)
+
+    """
+
+    supported_modes = {"per_example", "per_channel"}
+
+    supports_multichannel = True
+    requires_sample_rate = False
+
+    supports_target = True
+    requires_target = False
+
+    def __init__(
+        self,
+        min_snr_in_db: float = 0.0,
+        max_snr_in_db: float = 5.0,
+        mix_target: str = "union",
+        mode: str = "per_example",
+        p: float = 0.5,
+        p_mode: str = None,
+        sample_rate: int = None,
+        target_rate: int = None,
+    ):
+        super().__init__(
+            mode=mode,
+            p=p,
+            p_mode=p_mode,
+            sample_rate=sample_rate,
+            target_rate=target_rate,
+        )
+        self.min_snr_in_db = min_snr_in_db
+        self.max_snr_in_db = max_snr_in_db
+        if self.min_snr_in_db > self.max_snr_in_db:
+            raise ValueError("min_snr_in_db must not be greater than max_snr_in_db")
+
+        self.mix_target = mix_target
+        if mix_target == "original":
+            self._mix_target = lambda target, background_target, snr: target
+
+        elif mix_target == "union":
+            self._mix_target = lambda target, background_target, snr: torch.maximum(
+                target, background_target
+            )
+
+        else:
+            raise ValueError("mix_target must be one of 'original' or 'union'.")
+
+    def randomize_parameters(
+        self,
+        samples: Tensor = None,
+        sample_rate: Optional[int] = None,
+        targets: Optional[Tensor] = None,
+        target_rate: Optional[int] = None,
+    ):
+
+        batch_size, num_channels, num_samples = samples.shape
+        snr_distribution = torch.distributions.Uniform(
+            low=torch.tensor(
+                self.min_snr_in_db, dtype=torch.float32, device=samples.device,
+            ),
+            high=torch.tensor(
+                self.max_snr_in_db, dtype=torch.float32, device=samples.device,
+            ),
+            validate_args=True,
+        )
+
+        # randomize SNRs
+        self.transform_parameters["snr_in_db"] = snr_distribution.sample(
+            sample_shape=(batch_size,)
+        )
+
+        # randomize index of second sample
+        self.transform_parameters["sample_idx"] = torch.randint(
+            0, batch_size, (batch_size,), device=samples.device,
+        )
+
+    def apply_transform(
+        self,
+        samples: Tensor = None,
+        sample_rate: Optional[int] = None,
+        targets: Optional[Tensor] = None,
+        target_rate: Optional[int] = None,
+    ) -> ObjectDict:
+
+        snr = self.transform_parameters["snr_in_db"]
+        idx = self.transform_parameters["sample_idx"]
+
+        background_samples = Audio.rms_normalize(samples[idx])
+        background_rms = calculate_rms(samples) / (10 ** (snr.unsqueeze(dim=-1) / 20))
+
+        mixed_samples = samples + background_rms.unsqueeze(-1) * background_samples
+
+        if targets is None:
+            mixed_targets = None
+
+        else:
+            background_targets = targets[idx]
+            mixed_targets = self._mix_target(targets, background_targets, snr)
+
+        return ObjectDict(
+            samples=mixed_samples,
+            sample_rate=sample_rate,
+            targets=mixed_targets,
+            target_rate=target_rate,
+        )
diff --git a/torch_audiomentations/augmentations/peak_normalization.py b/torch_audiomentations/augmentations/peak_normalization.py
index 804b7f52..963a1724 100644
--- a/torch_audiomentations/augmentations/peak_normalization.py
+++ b/torch_audiomentations/augmentations/peak_normalization.py
@@ -1,7 +1,10 @@
 import torch
 import typing
+from typing import Optional
+from torch import Tensor
 
 from ..core.transforms_interface import BaseWaveformTransform
+from ..utils.object_dict import ObjectDict
 
 
 class PeakNormalization(BaseWaveformTransform):
@@ -16,8
+19,14 @@ class PeakNormalization(BaseWaveformTransform): untouched. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, apply_to="all", @@ -25,16 +34,27 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) assert apply_to in ("all", "only_too_loud_sounds") self.apply_to = apply_to def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): # Compute the most extreme value of each multichannel audio snippet in the batch - most_extreme_values, _ = torch.max(torch.abs(selected_samples), dim=-1) + most_extreme_values, _ = torch.max(torch.abs(samples), dim=-1) most_extreme_values, _ = torch.max(most_extreme_values, dim=-1) if self.apply_to == "all": @@ -55,9 +75,22 @@ def randomize_parameters( most_extreme_values[self.transform_parameters["selector"]], (-1, 1, 1) ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + if "divisors" in self.transform_parameters: - selected_samples[ - self.transform_parameters["selector"] - ] /= self.transform_parameters["divisors"] - return selected_samples + samples[self.transform_parameters["selector"]] /= self.transform_parameters[ + "divisors" + ] + + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/pitch_shift.py b/torch_audiomentations/augmentations/pitch_shift.py index b8f27641..dfda3da6 100644 --- a/torch_audiomentations/augmentations/pitch_shift.py +++ b/torch_audiomentations/augmentations/pitch_shift.py @@ -1,9 +1,11 @@ from random import choices -import torch +from torch import Tensor +from typing import Optional from torch_pitch_shift import pitch_shift, get_fast_shifts, semitones_to_ratio from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class PitchShift(BaseWaveformTransform): @@ -11,17 +13,23 @@ class PitchShift(BaseWaveformTransform): Pitch-shift sounds up or down without changing the tempo. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, - sample_rate: int, min_transpose_semitones: float = -4.0, max_transpose_semitones: float = 4.0, mode: str = "per_example", p: float = 0.5, p_mode: str = None, + sample_rate: int = None, + target_rate: int = None, ): """ :param sample_rate: @@ -30,8 +38,15 @@ def __init__( :param mode: ``per_example``, ``per_channel``, or ``per_batch``. Default ``per_example``. 
:param p: :param p_mode: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if min_transpose_semitones > max_transpose_semitones: raise ValueError("max_transpose_semitones must be > min_transpose_semitones") @@ -51,13 +66,17 @@ def __init__( self._mode = mode def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :param selected_samples: (batch_size, num_channels, num_samples) + :param samples: (batch_size, num_channels, num_samples) :param sample_rate: """ - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape if self._mode == "per_example": self.transform_parameters["transpositions"] = choices( @@ -75,12 +94,18 @@ def randomize_parameters( elif self._mode == "per_batch": self.transform_parameters["transpositions"] = choices(self._fast_shifts, k=1) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: """ - :param selected_samples: (batch_size, num_channels, num_samples) + :param samples: (batch_size, num_channels, num_samples) :param sample_rate: """ - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape if sample_rate is not None and sample_rate != self._sample_rate: raise ValueError( @@ -91,24 +116,29 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non if self._mode == "per_example": for i in range(batch_size): - selected_samples[i, ...] = pitch_shift( - selected_samples[i][None], + samples[i, ...] = pitch_shift( + samples[i][None], self.transform_parameters["transpositions"][i], sample_rate, )[0] + elif self._mode == "per_channel": for i in range(batch_size): for j in range(num_channels): - selected_samples[i, j, ...] = pitch_shift( - selected_samples[i][j][None][None], + samples[i, j, ...] = pitch_shift( + samples[i][j][None][None], self.transform_parameters["transpositions"][i][j], sample_rate, )[0][0] + elif self._mode == "per_batch": - return pitch_shift( - selected_samples, - self.transform_parameters["transpositions"][0], - sample_rate, + samples = pitch_shift( + samples, self.transform_parameters["transpositions"][0], sample_rate ) - return selected_samples + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/polarity_inversion.py b/torch_audiomentations/augmentations/polarity_inversion.py index d1d1d260..53461cb5 100644 --- a/torch_audiomentations/augmentations/polarity_inversion.py +++ b/torch_audiomentations/augmentations/polarity_inversion.py @@ -1,6 +1,8 @@ -import typing +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class PolarityInversion(BaseWaveformTransform): @@ -14,17 +16,42 @@ class PolarityInversion(BaseWaveformTransform): training phase-aware machine learning models. 
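The PolarityInversion rewrite just below keeps the one-line semantics (negate the waveform, pass any targets through) and only changes the packaging. A quick sanity check under the new keyword-argument API introduced in this diff; this is an illustrative sketch, not part of the change set:

    import torch
    from torch_audiomentations import PolarityInversion

    samples = torch.rand(4, 2, 8000) - 0.5
    transform = PolarityInversion(p=1.0)

    output = transform(samples=samples, sample_rate=16000)
    assert torch.equal(output.samples, -samples)  # polarity is inverted
    assert output.targets is None  # no targets were passed in, none come back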
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) + + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + return ObjectDict( + samples=-samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - return -selected_samples diff --git a/torch_audiomentations/augmentations/shift.py b/torch_audiomentations/augmentations/shift.py index 97b5a17f..fd471717 100644 --- a/torch_audiomentations/augmentations/shift.py +++ b/torch_audiomentations/augmentations/shift.py @@ -1,7 +1,9 @@ import torch -import typing +from typing import Optional +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict def shift_gpu(tensor: torch.Tensor, r: torch.Tensor, rollover: bool = False): @@ -51,6 +53,14 @@ class Shift(BaseWaveformTransform): Shift the audio forwards or backwards, with or without rollover """ + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True + requires_sample_rate = True + + supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) + requires_target = False + def __init__( self, min_shift: float = -0.5, @@ -59,8 +69,9 @@ def __init__( rollover: bool = True, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): """ @@ -78,8 +89,15 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_shift = min_shift self.max_shift = max_shift self.shift_unit = shift_unit @@ -90,17 +108,24 @@ def __init__( raise ValueError('shift_unit must be "samples", "fraction" or "seconds"') def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): if self.shift_unit == "samples": min_shift_in_samples = self.min_shift max_shift_in_samples = self.max_shift + elif self.shift_unit == "fraction": - min_shift_in_samples = int(round(self.min_shift * selected_samples.shape[-1])) - max_shift_in_samples = int(round(self.max_shift * selected_samples.shape[-1])) + min_shift_in_samples = int(round(self.min_shift * samples.shape[-1])) + max_shift_in_samples = int(round(self.max_shift * samples.shape[-1])) + elif self.shift_unit == "seconds": min_shift_in_samples = int(round(self.min_shift * sample_rate)) max_shift_in_samples = 
int(round(self.max_shift * sample_rate)) + else: raise ValueError("Invalid shift_unit") @@ -114,28 +139,55 @@ def randomize_parameters( <= max_shift_in_samples <= torch.iinfo(torch.int32).max ) - selected_batch_size = selected_samples.size(0) + selected_batch_size = samples.size(0) if min_shift_in_samples == max_shift_in_samples: self.transform_parameters["num_samples_to_shift"] = torch.full( size=(selected_batch_size,), fill_value=min_shift_in_samples, dtype=torch.int32, - device=selected_samples.device, + device=samples.device, ) + else: self.transform_parameters["num_samples_to_shift"] = torch.randint( low=min_shift_in_samples, high=max_shift_in_samples + 1, size=(selected_batch_size,), dtype=torch.int32, - device=selected_samples.device, + device=samples.device, ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - r = self.transform_parameters["num_samples_to_shift"] + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + num_samples_to_shift = self.transform_parameters["num_samples_to_shift"] + # Select fastest implementation based on device - shift = shift_gpu if selected_samples.device.type == "cuda" else shift_cpu - return shift(selected_samples, r, self.rollover) + shift = shift_gpu if samples.device.type == "cuda" else shift_cpu + shifted_samples = shift(samples, num_samples_to_shift, self.rollover) + + if targets is None or target_rate == 0: + shifted_targets = targets + + else: + num_frames_to_shift = int( + round(target_rate * num_samples_to_shift / sample_rate) + ) + shifted_targets = shift( + targets.transpose(-2, -1), num_frames_to_shift, self.rollover + ).transpose(-2, -1) + + return ObjectDict( + samples=shifted_samples, + sample_rate=sample_rate, + targets=shifted_targets, + target_rate=target_rate, + ) def is_sample_rate_required(self) -> bool: # Sample rate is required only if shift_unit is "seconds" diff --git a/torch_audiomentations/augmentations/shuffle_channels.py b/torch_audiomentations/augmentations/shuffle_channels.py index 7f06eea5..ad1c15eb 100644 --- a/torch_audiomentations/augmentations/shuffle_channels.py +++ b/torch_audiomentations/augmentations/shuffle_channels.py @@ -1,9 +1,12 @@ -import typing +from typing import Optional import warnings import torch +from torch import Tensor + from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class ShuffleChannels(BaseWaveformTransform): @@ -22,33 +25,63 @@ def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): - batch_size = selected_samples.shape[0] - num_channels = selected_samples.shape[1] + batch_size = samples.shape[0] + num_channels = samples.shape[1] assert num_channels <= 255 permutations = torch.zeros( - (batch_size, num_channels), dtype=torch.int64, device=selected_samples.device + 
(batch_size, num_channels), dtype=torch.int64, device=samples.device ) for i in range(batch_size): - permutations[i] = torch.randperm(num_channels, device=selected_samples.device) + permutations[i] = torch.randperm(num_channels, device=samples.device) self.transform_parameters["permutations"] = permutations - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - if selected_samples.shape[1] == 1: + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + if samples.shape[1] == 1: warnings.warn( "Mono audio was passed to ShuffleChannels - there are no channels to shuffle." " The input will be returned unchanged." ) - return selected_samples - for i in range(selected_samples.size(0)): - selected_samples[i] = selected_samples[ - i, self.transform_parameters["permutations"][i] - ] - return selected_samples + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + + for i in range(samples.size(0)): + samples[i] = samples[i, self.transform_parameters["permutations"][i]] + if targets is not None: + targets[i] = targets[i, self.transform_parameters["permutations"][i]] + + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/time_inversion.py b/torch_audiomentations/augmentations/time_inversion.py index 53191d71..c1ef6676 100644 --- a/torch_audiomentations/augmentations/time_inversion.py +++ b/torch_audiomentations/augmentations/time_inversion.py @@ -1,6 +1,9 @@ import torch +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class TimeInversion(BaseWaveformTransform): @@ -12,15 +15,21 @@ class TimeInversion(BaseWaveformTransform): https://arxiv.org/pdf/2106.13043.pdf """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, mode: str = "per_example", p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param mode: @@ -28,12 +37,37 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) + + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): # torch.flip() is supposed to be slower than np.flip() # An alternative is to use advanced indexing: https://github.com/pytorch/pytorch/issues/16424 # reverse_index = torch.arange(selected_samples.size(-1) - 1, -1, -1).to(selected_samples.device) # transformed_samples = selected_samples[..., reverse_index] # return transformed_samples - return torch.flip(selected_samples, dims=(-1,)) + + flipped_samples = torch.flip(samples, dims=(-1,)) + if targets is None: + flipped_targets = targets + else: + flipped_targets = torch.flip(targets, dims=(-2,)) + + return ObjectDict( + samples=flipped_samples, + sample_rate=sample_rate, + targets=flipped_targets, + 
target_rate=target_rate, + ) diff --git a/torch_audiomentations/core/composition.py b/torch_audiomentations/core/composition.py index e37bd07f..ab9443ff 100644 --- a/torch_audiomentations/core/composition.py +++ b/torch_audiomentations/core/composition.py @@ -1,10 +1,11 @@ import random -from typing import List +from typing import List, Union, Optional, Tuple -import torch -import typing +from torch import Tensor +import torch.nn from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations.utils.object_dict import ObjectDict class BaseCompose(torch.nn.Module): @@ -12,7 +13,9 @@ class BaseCompose(torch.nn.Module): def __init__( self, - transforms: List[torch.nn.Module], + transforms: List[ + torch.nn.Module + ], # FIXME: do we really want to support regular nn.Module? shuffle: bool = False, p: float = 1.0, p_mode="per_batch", @@ -63,18 +66,36 @@ def supported_modes(self) -> set: class Compose(BaseCompose): - def forward(self, samples, sample_rate: typing.Optional[int] = None): + def forward( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + inputs = ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if random.random() < self.p: transform_indexes = list(range(len(self.transforms))) if self.shuffle: random.shuffle(transform_indexes) for i in transform_indexes: tfm = self.transforms[i] - if isinstance(tfm, BaseWaveformTransform): - samples = self.transforms[i](samples, sample_rate) + if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): + inputs = self.transforms[i](**inputs) + else: - samples = self.transforms[i](samples) - return samples + assert isinstance(tfm, torch.nn.Module) + # FIXME: do we really want to support regular nn.Module? + inputs.samples = self.transforms[i](inputs.samples) + + return inputs class SomeOf(BaseCompose): @@ -96,7 +117,7 @@ class SomeOf(BaseCompose): def __init__( self, - num_transforms: typing.Union[int, typing.Tuple[int, int]], + num_transforms: Union[int, Tuple[int, int]], transforms: List[torch.nn.Module], p: float = 1.0, p_mode="per_batch", @@ -131,7 +152,21 @@ def randomize_parameters(self): random.sample(self.all_transforms_indexes, num_transforms_to_apply) ) - def forward(self, samples, sample_rate: typing.Optional[int] = None): + def forward( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + inputs = ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if random.random() < self.p: if not self.are_parameters_frozen: @@ -139,11 +174,15 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): for i in self.transform_indexes: tfm = self.transforms[i] - if isinstance(tfm, BaseWaveformTransform): - samples = self.transforms[i](samples, sample_rate) + if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): + inputs = self.transforms[i](**inputs) + else: - samples = self.transforms[i](samples) - return samples + assert isinstance(tfm, torch.nn.Module) + # FIXME: do we really want to support regular nn.Module? 
+ inputs.samples = self.transforms[i](inputs.samples) + + return inputs class OneOf(SomeOf): diff --git a/torch_audiomentations/core/transforms_interface.py b/torch_audiomentations/core/transforms_interface.py index 7cc90ca8..bab32c7a 100644 --- a/torch_audiomentations/core/transforms_interface.py +++ b/torch_audiomentations/core/transforms_interface.py @@ -1,10 +1,12 @@ import warnings import torch -import typing +from torch import Tensor +from typing import Optional from torch.distributions import Bernoulli from torch_audiomentations.utils.multichannel import is_multichannel +from torch_audiomentations.utils.object_dict import ObjectDict class MultichannelAudioNotSupportedException(Exception): @@ -20,16 +22,22 @@ class ModeNotSupportedException(Exception): class BaseWaveformTransform(torch.nn.Module): - supports_multichannel = True + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): """ @@ -49,6 +57,8 @@ def __init__( with different parameters. Default value: Same as mode. :param sample_rate: sample_rate can be set either here or when calling the transform. + :param target_rate: target_rate can be set either here or when + calling the transform. """ super().__init__() assert 0.0 <= p <= 1.0 @@ -58,6 +68,7 @@ def __init__( if self.p_mode is None: self.p_mode = self.mode self.sample_rate = sample_rate + self.target_rate = target_rate # Check validity of mode/p_mode combination if self.mode not in self.supported_modes: @@ -87,30 +98,52 @@ def p(self, p): # Update the Bernoulli distribution accordingly self.bernoulli_distribution = Bernoulli(self._p) - def forward(self, samples, sample_rate: typing.Optional[int] = None): - if not self.training: - return samples + def forward( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + # TODO: add support for additional **kwargs (batch_size, ...)-shaped tensors + # TODO: but do that only when we actually have a use case for it... + ) -> ObjectDict: - if len(samples) == 0: - warnings.warn( - "An empty samples tensor was passed to {}".format(self.__class__.__name__) + if not self.training: + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, ) - return samples - if len(samples.shape) != 3: + if not isinstance(samples, Tensor) or len(samples.shape) != 3: raise RuntimeError( - "torch-audiomentations expects input tensors to be three-dimensional, with" + "torch-audiomentations expects three-dimensional input tensors, with" " dimension ordering like [batch_size, num_channels, num_samples]. If your" " audio is mono, you can use a shape like [batch_size, 1, num_samples]." 
) + batch_size, num_channels, num_samples = samples.shape + + if batch_size * num_channels * num_samples == 0: + warnings.warn( + "An empty samples tensor was passed to {}".format(self.__class__.__name__) + ) + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if is_multichannel(samples): - if samples.shape[1] > samples.shape[2]: + if num_channels > num_samples: warnings.warn( "Multichannel audio must have channels first, not channels last. In" " other words, the shape must be (batch size, channels, samples), not" " (batch_size, samples, channels)" ) + if not self.supports_multichannel: raise MultichannelAudioNotSupportedException( "{} only supports mono audio, not multichannel audio".format( @@ -122,15 +155,69 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): if sample_rate is None and self.is_sample_rate_required(): raise RuntimeError("sample_rate is required") + if targets is None and self.is_target_required(): + raise RuntimeError("targets is required") + + has_targets = targets is not None + + if has_targets and not self.supports_target: + warnings.warn(f"Targets are not (yet) supported by {self.__class__.__name__}") + + if has_targets: + + if not isinstance(targets, Tensor) or len(targets.shape) != 4: + + raise RuntimeError( + "torch-audiomentations expects four-dimensional target tensors, with" + " dimension ordering like [batch_size, num_channels, num_frames, num_classes]." + " If your target is binary, you can use a shape like [batch_size, num_channels, num_frames, 1]." + " If your target is for the whole channel, you can use a shape like [batch_size, num_channels, 1, num_classes]." + ) + + ( + target_batch_size, + target_num_channels, + num_frames, + num_classes, + ) = targets.shape + + if target_batch_size != batch_size: + raise RuntimeError( + f"samples ({batch_size}) and target ({target_batch_size}) batch sizes must be equal." + ) + if num_channels != target_num_channels: + raise RuntimeError( + f"samples ({num_channels}) and target ({target_num_channels}) number of channels must be equal." + ) + + target_rate = target_rate or self.target_rate + if target_rate is None: + if num_frames > 1: + target_rate = round(sample_rate * num_frames / num_samples) + warnings.warn( + f"target_rate is required by {self.__class__.__name__}. " + f"It has been automatically inferred from targets shape to {target_rate}. " + f"If this is incorrect, you can pass it directly." + ) + else: + # corner case where num_frames == 1, meaning that the target is for the whole sample, + # not frame-based. we arbitrarily set target_rate to 0. 
+ target_rate = 0 + if not self.are_parameters_frozen: + if self.p_mode == "per_example": - p_sample_size = samples.shape[0] + p_sample_size = batch_size + elif self.p_mode == "per_channel": - p_sample_size = samples.shape[0] * samples.shape[1] + p_sample_size = batch_size * num_channels + elif self.p_mode == "per_batch": p_sample_size = 1 + else: raise Exception("Invalid mode") + self.transform_parameters = { "should_apply": self.bernoulli_distribution.sample( sample_shape=(p_sample_size,) @@ -138,109 +225,286 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): } if self.transform_parameters["should_apply"].any(): + cloned_samples = samples.clone() + if has_targets: + cloned_targets = targets.clone() + else: + cloned_targets = None + selected_targets = None + if self.p_mode == "per_channel": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - batch_size * num_channels, 1, cloned_samples.shape[2] + batch_size * num_channels, 1, num_samples ) selected_samples = cloned_samples[ self.transform_parameters["should_apply"] ] + if has_targets: + cloned_targets = cloned_targets.reshape( + batch_size * num_channels, 1, num_frames, num_classes + ) + selected_targets = cloned_targets[ + self.transform_parameters["should_apply"] + ] + if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + self.randomize_parameters( + samples=selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed: ObjectDict = self.apply_transform( + samples=selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = self.apply_transform(selected_samples, sample_rate) - + ] = perturbed.samples cloned_samples = cloned_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + batch_size, num_channels, num_samples ) - return cloned_samples + if has_targets: + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed.targets + cloned_targets = cloned_targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) elif self.p_mode == "per_example": + selected_samples = cloned_samples[ self.transform_parameters["should_apply"] ] + if has_targets: + selected_targets = cloned_targets[ + self.transform_parameters["should_apply"] + ] + if self.mode == "per_example": + if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + self.randomize_parameters( + samples=selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed: ObjectDict = self.apply_transform( + samples=selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = self.apply_transform(selected_samples, sample_rate) - return cloned_samples + ] = perturbed.samples + + if has_targets: + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed.targets + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) + elif self.mode == "per_channel": - batch_size = selected_samples.shape[0] - 
num_channels = selected_samples.shape[1] + + ( + selected_batch_size, + selected_num_channels, + selected_num_samples, + ) = selected_samples.shape + + assert selected_num_samples == num_samples + selected_samples = selected_samples.reshape( - batch_size * num_channels, 1, selected_samples.shape[2] + selected_batch_size * selected_num_channels, + 1, + selected_num_samples, ) - if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + if has_targets: + selected_targets = selected_targets.reshape( + selected_batch_size * selected_num_channels, + 1, + num_frames, + num_classes, + ) - perturbed_samples = self.apply_transform( - selected_samples, sample_rate - ) - perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, selected_samples.shape[2] + if not self.are_parameters_frozen: + self.randomize_parameters( + samples=selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed: ObjectDict = self.apply_transform( + selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, ) + perturbed.samples = perturbed.samples.reshape( + selected_batch_size, selected_num_channels, selected_num_samples + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = perturbed_samples - return cloned_samples + ] = perturbed.samples + + if has_targets: + perturbed.targets = perturbed.targets.reshape( + selected_batch_size, + selected_num_channels, + num_frames, + num_classes, + ) + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed.targets + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) + else: raise Exception("Invalid mode/p_mode combination") + elif self.p_mode == "per_batch": + if self.mode == "per_batch": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - 1, batch_size * num_channels, cloned_samples.shape[2] + 1, batch_size * num_channels, num_samples ) - if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) + if has_targets: + cloned_targets = cloned_targets.reshape( + 1, batch_size * num_channels, num_frames, num_classes + ) - perturbed_samples = self.apply_transform(cloned_samples, sample_rate) - perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + if not self.are_parameters_frozen: + self.randomize_parameters( + samples=cloned_samples, + sample_rate=sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed: ObjectDict = self.apply_transform( + samples=cloned_samples, + sample_rate=sample_rate, + targets=cloned_targets, + target_rate=target_rate, ) - return perturbed_samples + perturbed.samples = perturbed.samples.reshape( + batch_size, num_channels, num_samples + ) + + if has_targets: + perturbed.targets = perturbed.targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + + return perturbed + elif self.mode == "per_example": + if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) - return self.apply_transform(cloned_samples, sample_rate) + self.randomize_parameters( + samples=cloned_samples, + sample_rate=sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed = self.apply_transform( + samples=cloned_samples, + sample_rate=sample_rate, + 
targets=cloned_targets, + target_rate=target_rate, + ) + + return perturbed + elif self.mode == "per_channel": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - batch_size * num_channels, 1, cloned_samples.shape[2] + batch_size * num_channels, 1, num_samples ) - if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) + if has_targets: + cloned_targets = cloned_targets.reshape( + batch_size * num_channels, 1, num_frames, num_classes + ) - perturbed_samples = self.apply_transform(cloned_samples, sample_rate) + if not self.are_parameters_frozen: + self.randomize_parameters( + samples=cloned_samples, + sample_rate=sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed: ObjectDict = self.apply_transform( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) - perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + perturbed.samples = perturbed.samples.reshape( + batch_size, num_channels, num_samples ) - return perturbed_samples + + if has_targets: + perturbed.targets = perturbed.targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + + return perturbed + else: raise Exception("Invalid mode") + else: raise Exception("Invalid p_mode {}".format(self.p_mode)) - return samples + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) def _forward_unimplemented(self, *inputs) -> None: # Avoid IDE error message like "Class ... must implement all abstract methods" @@ -248,11 +512,22 @@ def _forward_unimplemented(self, *inputs) -> None: pass def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): pass - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): + def apply_transform( + self, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + raise NotImplementedError() def serialize_parameters(self): @@ -276,3 +551,6 @@ def unfreeze_parameters(self): def is_sample_rate_required(self) -> bool: return self.requires_sample_rate + + def is_target_required(self) -> bool: + return self.requires_target diff --git a/torch_audiomentations/utils/io.py b/torch_audiomentations/utils/io.py index ad829666..dbe2a6ba 100644 --- a/torch_audiomentations/utils/io.py +++ b/torch_audiomentations/utils/io.py @@ -84,16 +84,16 @@ def rms_normalize(samples: Tensor) -> Tensor: Parameters ---------- - samples : (channel, time) Tensor - Single or multichannel samples + samples : (..., time) Tensor + Single (or multichannel) samples or batch of samples Returns ------- - samples: (channel, time) Tensor + samples: (..., time) Tensor Power-normalized samples """ - rms = samples.square().mean(dim=1).sqrt() - return (samples.t() / (rms + 1e-8)).t() + rms = samples.square().mean(dim=-1, keepdim=True).sqrt() + return samples / (rms + 1e-8) @staticmethod def get_audio_metadata(file_path) -> tuple: diff --git a/torch_audiomentations/utils/object_dict.py b/torch_audiomentations/utils/object_dict.py new file mode 100644 index 00000000..f59ad186 --- /dev/null +++ b/torch_audiomentations/utils/object_dict.py @@ -0,0 +1,35 @@ +# Inspired by 
tornado
+# https://www.tornadoweb.org/en/stable/_modules/tornado/util.html#ObjectDict
+
+import typing
+
+_ObjectDictBase = typing.Dict[str, typing.Any]
+
+
+class ObjectDict(_ObjectDictBase):
+    """
+    Make a dictionary behave like an object, with attribute-style access.
+
+    Here are some examples of how it can be used:
+
+    o = ObjectDict(my_dict)
+    # or like this:
+    o = ObjectDict(samples=samples, sample_rate=sample_rate)
+
+    # Attribute-style access
+    samples = o.samples
+
+    # Dict-style access
+    samples = o["samples"]
+    """
+
+    def __getattr__(self, name):
+        # type: (str) -> Any
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name, value):
+        # type: (str, Any) -> None
+        self[name] = value
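Putting the pieces together, the new ObjectDict is what lets Compose (and nested composes) pipe a transform's output straight into the next transform via keyword unpacking. A minimal sketch of both access styles and the chaining idiom used in composition.py above; the values here are made up for illustration:

    import torch
    from torch_audiomentations.utils.object_dict import ObjectDict

    o = ObjectDict(
        samples=torch.rand(2, 1, 16000), sample_rate=16000, targets=None, target_rate=None
    )

    assert o.samples is o["samples"]  # attribute- and dict-style access hit the same object
    o.sample_rate = 16000             # attribute-style assignment works as well

    # Because ObjectDict is a plain dict underneath, the compositions above can
    # simply chain transforms with: inputs = transform(**inputs)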