From 66fdbb5d8bda8bd28e4d27b46328f30620ca3091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 22 Mar 2022 21:28:22 +0100 Subject: [PATCH 01/15] wip: add support for targets as discussed in #3 --- tests/test_pitch_shift.py | 14 +- .../augmentations/background_noise.py | 32 +- .../augmentations/band_pass_filter.py | 36 +- .../augmentations/band_stop_filter.py | 28 +- .../augmentations/colored_noise.py | 33 +- torch_audiomentations/augmentations/gain.py | 25 +- .../augmentations/high_pass_filter.py | 30 +- .../augmentations/impulse_response.py | 34 +- .../augmentations/low_pass_filter.py | 25 +- .../augmentations/peak_normalization.py | 25 +- .../augmentations/pitch_shift.py | 39 ++- .../augmentations/polarity_inversion.py | 20 +- torch_audiomentations/augmentations/shift.py | 52 ++- .../augmentations/shuffle_channels.py | 32 +- .../augmentations/time_inversion.py | 27 +- torch_audiomentations/core/composition.py | 32 +- .../core/transforms_interface.py | 320 +++++++++++++++--- 17 files changed, 664 insertions(+), 140 deletions(-) diff --git a/tests/test_pitch_shift.py b/tests/test_pitch_shift.py index 0e9f4080..a3edcd2e 100644 --- a/tests/test_pitch_shift.py +++ b/tests/test_pitch_shift.py @@ -21,30 +21,30 @@ def get_example(): class TestPitchShift(unittest.TestCase): def test_per_example_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_example") + aug = PitchShift(sample_rate=16000, p=1, mode="per_example") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results, _ = aug.apply_transform(samples) self.assertEqual(results.shape, samples.shape) def test_per_channel_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_channel") + aug = PitchShift(sample_rate=16000, p=1, mode="per_channel") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results, _ = aug.apply_transform(samples) self.assertEqual(results.shape, samples.shape) def test_per_batch_shift(self): samples = get_example() - aug = PitchShift(16000, p=1, mode="per_batch") + aug = PitchShift(sample_rate=16000, p=1, mode="per_batch") aug.randomize_parameters(samples) - results = aug.apply_transform(samples) + results, _ = aug.apply_transform(samples) self.assertEqual(results.shape, samples.shape) def error_raised(self): error = False try: PitchShift( - 16000, + sample_rate=16000, p=1, mode="per_example", min_transpose_semitones=0.0, diff --git a/torch_audiomentations/augmentations/background_noise.py b/torch_audiomentations/augmentations/background_noise.py index 2e391995..47262752 100644 --- a/torch_audiomentations/augmentations/background_noise.py +++ b/torch_audiomentations/augmentations/background_noise.py @@ -29,6 +29,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ @@ -42,7 +43,13 @@ def __init__( :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if isinstance(background_paths, (list, tuple, set)): # TODO: check that one can read audio files @@ -94,7 +101,11 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens ) def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): """ @@ -135,7 +146,13 @@ def 
randomize_parameters( sample_shape=(batch_size,) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, num_channels, num_samples = selected_samples.shape # (batch_size, num_samples) @@ -146,6 +163,9 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return selected_samples + background_rms.unsqueeze(-1) * background.view( - batch_size, 1, num_samples - ).expand(-1, num_channels, -1) + return ( + selected_samples + + background_rms.unsqueeze(-1) + * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), + targets, + ) diff --git a/torch_audiomentations/augmentations/band_pass_filter.py b/torch_audiomentations/augmentations/band_pass_filter.py index 8291f15a..37e7e512 100644 --- a/torch_audiomentations/augmentations/band_pass_filter.py +++ b/torch_audiomentations/augmentations/band_pass_filter.py @@ -23,6 +23,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_center_frequency: Minimum center frequency in hertz @@ -36,7 +37,13 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_center_frequency = min_center_frequency self.max_center_frequency = max_center_frequency @@ -65,7 +72,11 @@ def __init__( ) def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): """ :params selected_samples: (batch_size, num_channels, num_samples) @@ -77,16 +88,12 @@ def get_dist(min_freq, max_freq): dist = torch.distributions.Uniform( low=convert_frequencies_to_mels( torch.tensor( - min_freq, - dtype=torch.float32, - device=selected_samples.device, + min_freq, dtype=torch.float32, device=selected_samples.device, ) ), high=convert_frequencies_to_mels( torch.tensor( - max_freq, - dtype=torch.float32, - device=selected_samples.device, + max_freq, dtype=torch.float32, device=selected_samples.device, ) ), validate_args=True, @@ -99,14 +106,19 @@ def get_dist(min_freq, max_freq): ) bandwidth_dist = torch.distributions.Uniform( - low=self.min_bandwidth_fraction, - high=self.max_bandwidth_fraction, + low=self.min_bandwidth_fraction, high=self.max_bandwidth_fraction, ) self.transform_parameters["bandwidth"] = bandwidth_dist.sample( sample_shape=(batch_size,) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, num_channels, num_samples = selected_samples.shape if sample_rate is None: @@ -130,4 +142,4 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non cutoff_high=high_cutoffs_as_fraction_of_sample_rate[i].item(), ) - return selected_samples + return selected_samples, targets diff --git a/torch_audiomentations/augmentations/band_stop_filter.py b/torch_audiomentations/augmentations/band_stop_filter.py index 21c33eea..991b89e7 100644 --- 
a/torch_audiomentations/augmentations/band_stop_filter.py +++ b/torch_audiomentations/augmentations/band_stop_filter.py @@ -22,6 +22,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_center_frequency: Minimum center frequency in hertz @@ -34,6 +35,7 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ super().__init__( @@ -41,14 +43,24 @@ def __init__( max_center_frequency, min_bandwidth_fraction, max_bandwidth_fraction, - mode, - p, - p_mode, - sample_rate, + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - band_pass_filtered_samples = super().apply_transform( - selected_samples.clone(), sample_rate + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + band_pass_filtered_samples, band_pass_filtered_targets = super().apply_transform( + selected_samples.clone(), + sample_rate, + targets=targets.clone() if targets is not None else None, + target_rate=target_rate, ) - return selected_samples - band_pass_filtered_samples + return selected_samples - band_pass_filtered_samples, band_pass_filtered_targets diff --git a/torch_audiomentations/augmentations/colored_noise.py b/torch_audiomentations/augmentations/colored_noise.py index 0c285b6d..6dd30bf6 100644 --- a/torch_audiomentations/augmentations/colored_noise.py +++ b/torch_audiomentations/augmentations/colored_noise.py @@ -46,6 +46,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_snr_in_db: minimum SNR in dB. 
@@ -61,9 +62,16 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_snr_in_db = min_snr_in_db self.max_snr_in_db = max_snr_in_db @@ -76,7 +84,11 @@ def __init__( raise ValueError("min_f_decay must not be greater than max_f_decay") def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): """ :params selected_samples: (batch_size, num_channels, num_samples) @@ -99,7 +111,13 @@ def randomize_parameters( ) self.transform_parameters[param] = dist.sample(sample_shape=(batch_size,)) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, num_channels, num_samples = selected_samples.shape if sample_rate is None: @@ -123,6 +141,9 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return selected_samples + noise_rms.unsqueeze(-1) * noise.view( - batch_size, 1, num_samples - ).expand(-1, num_channels, -1) + return ( + selected_samples + + noise_rms.unsqueeze(-1) + * noise.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), + targets, + ) diff --git a/torch_audiomentations/augmentations/gain.py b/torch_audiomentations/augmentations/gain.py index 278aacc2..ad49709a 100644 --- a/torch_audiomentations/augmentations/gain.py +++ b/torch_audiomentations/augmentations/gain.py @@ -25,15 +25,26 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_gain_in_db = min_gain_in_db self.max_gain_in_db = max_gain_in_db if self.min_gain_in_db >= self.max_gain_in_db: raise ValueError("max_gain_in_db must be higher than min_gain_in_db") def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): distribution = torch.distributions.Uniform( low=torch.tensor( @@ -53,5 +64,11 @@ def randomize_parameters( .unsqueeze(1) ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - return selected_samples * self.transform_parameters["gain_factors"] + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + return selected_samples * self.transform_parameters["gain_factors"], targets diff --git a/torch_audiomentations/augmentations/high_pass_filter.py b/torch_audiomentations/augmentations/high_pass_filter.py index 07f023fb..1b0a7c5a 100644 --- a/torch_audiomentations/augmentations/high_pass_filter.py +++ b/torch_audiomentations/augmentations/high_pass_filter.py @@ -19,6 +19,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param 
min_cutoff_freq: Minimum cutoff frequency in hertz @@ -27,11 +28,30 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(min_cutoff_freq, max_cutoff_freq, mode, p, p_mode, sample_rate) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): - low_pass_filtered_samples = super().apply_transform( - selected_samples.clone(), sample_rate + super().__init__( + min_cutoff_freq, + max_cutoff_freq, + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, ) - return selected_samples - low_pass_filtered_samples + + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + low_pass_filtered_samples, low_pass_filtered_targets = super().apply_transform( + selected_samples.clone(), + sample_rate, + targets=targets.clone() if targets is not None else None, + target_rate=target_rate, + ) + return selected_samples - low_pass_filtered_samples, low_pass_filtered_targets diff --git a/torch_audiomentations/augmentations/impulse_response.py b/torch_audiomentations/augmentations/impulse_response.py index e0f5d4e5..ee12d1b5 100644 --- a/torch_audiomentations/augmentations/impulse_response.py +++ b/torch_audiomentations/augmentations/impulse_response.py @@ -30,6 +30,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param ir_paths: Either a path to a folder with audio files or a list of paths to audio files. @@ -43,8 +44,16 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if isinstance(ir_paths, (list, tuple, set)): # TODO: check that one can read audio files @@ -61,7 +70,13 @@ def __init__( self.convolve_mode = convolve_mode self.compensate_for_propagation_delay = compensate_for_propagation_delay - def randomize_parameters(self, selected_samples, sample_rate: int = None): + def randomize_parameters( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, _, _ = selected_samples.shape @@ -77,7 +92,13 @@ def randomize_parameters(self, selected_samples, sample_rate: int = None): self.transform_parameters["ir_paths"] = random_ir_paths - def apply_transform(self, selected_samples, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, num_channels, num_samples = selected_samples.shape @@ -90,7 +111,6 @@ def apply_transform(self, selected_samples, sample_rate: int = None): if self.compensate_for_propagation_delay: propagation_delays = ir.abs().argmax(dim=2, keepdim=False)[:, 0] - convolved_samples = torch.stack( [ convolved_sample[ @@ -103,7 +123,9 @@ def apply_transform(self, selected_samples, sample_rate: int = None): dim=0, ) - return convolved_samples + # FIXME should we compensate targets as well? + return convolved_samples, targets else: - return convolved_samples[..., :num_samples] + # FIXME should we strip targets as well? 
+ return convolved_samples[..., :num_samples], targets diff --git a/torch_audiomentations/augmentations/low_pass_filter.py b/torch_audiomentations/augmentations/low_pass_filter.py index 1f6812e2..ace821e6 100644 --- a/torch_audiomentations/augmentations/low_pass_filter.py +++ b/torch_audiomentations/augmentations/low_pass_filter.py @@ -21,6 +21,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param min_cutoff_freq: Minimum cutoff frequency in hertz @@ -30,7 +31,13 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_cutoff_freq = min_cutoff_freq self.max_cutoff_freq = max_cutoff_freq @@ -38,7 +45,11 @@ def __init__( raise ValueError("min_cutoff_freq must not be greater than max_cutoff_freq") def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): """ :params selected_samples: (batch_size, num_channels, num_samples) @@ -67,7 +78,13 @@ def randomize_parameters( dist.sample(sample_shape=(batch_size,)) ) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): batch_size, num_channels, num_samples = selected_samples.shape if sample_rate is None: @@ -82,4 +99,4 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non selected_samples[i], cutoffs_as_fraction_of_sample_rate[i].item() ) - return selected_samples + return selected_samples, targets diff --git a/torch_audiomentations/augmentations/peak_normalization.py b/torch_audiomentations/augmentations/peak_normalization.py index 804b7f52..0dadfcb8 100644 --- a/torch_audiomentations/augmentations/peak_normalization.py +++ b/torch_audiomentations/augmentations/peak_normalization.py @@ -25,13 +25,24 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) assert apply_to in ("all", "only_too_loud_sounds") self.apply_to = apply_to def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): # Compute the most extreme value of each multichannel audio snippet in the batch most_extreme_values, _ = torch.max(torch.abs(selected_samples), dim=-1) @@ -55,9 +66,15 @@ def randomize_parameters( most_extreme_values[self.transform_parameters["selector"]], (-1, 1, 1) ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): if "divisors" in self.transform_parameters: selected_samples[ self.transform_parameters["selector"] ] /= self.transform_parameters["divisors"] - return selected_samples + return selected_samples, targets diff --git 
a/torch_audiomentations/augmentations/pitch_shift.py b/torch_audiomentations/augmentations/pitch_shift.py index b8f27641..bde673bc 100644 --- a/torch_audiomentations/augmentations/pitch_shift.py +++ b/torch_audiomentations/augmentations/pitch_shift.py @@ -16,12 +16,13 @@ class PitchShift(BaseWaveformTransform): def __init__( self, - sample_rate: int, min_transpose_semitones: float = -4.0, max_transpose_semitones: float = 4.0, mode: str = "per_example", p: float = 0.5, p_mode: str = None, + sample_rate: int = None, + target_rate: int = None, ): """ :param sample_rate: @@ -30,8 +31,15 @@ def __init__( :param mode: ``per_example``, ``per_channel``, or ``per_batch``. Default ``per_example``. :param p: :param p_mode: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) if min_transpose_semitones > max_transpose_semitones: raise ValueError("max_transpose_semitones must be > min_transpose_semitones") @@ -51,7 +59,11 @@ def __init__( self._mode = mode def randomize_parameters( - self, selected_samples: torch.Tensor, sample_rate: int = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): """ :param selected_samples: (batch_size, num_channels, num_samples) @@ -75,7 +87,13 @@ def randomize_parameters( elif self._mode == "per_batch": self.transform_parameters["transpositions"] = choices(self._fast_shifts, k=1) - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): """ :param selected_samples: (batch_size, num_channels, num_samples) :param sample_rate: @@ -105,10 +123,13 @@ def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = Non sample_rate, )[0][0] elif self._mode == "per_batch": - return pitch_shift( - selected_samples, - self.transform_parameters["transpositions"][0], - sample_rate, + return ( + pitch_shift( + selected_samples, + self.transform_parameters["transpositions"][0], + sample_rate, + ), + targets, ) - return selected_samples + return selected_samples, targets diff --git a/torch_audiomentations/augmentations/polarity_inversion.py b/torch_audiomentations/augmentations/polarity_inversion.py index d1d1d260..4894afc3 100644 --- a/torch_audiomentations/augmentations/polarity_inversion.py +++ b/torch_audiomentations/augmentations/polarity_inversion.py @@ -1,4 +1,5 @@ import typing +import torch from ..core.transforms_interface import BaseWaveformTransform @@ -23,8 +24,21 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - return -selected_samples + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + return -selected_samples, targets diff --git a/torch_audiomentations/augmentations/shift.py b/torch_audiomentations/augmentations/shift.py index 97b5a17f..f32c562e 100644 --- 
a/torch_audiomentations/augmentations/shift.py +++ b/torch_audiomentations/augmentations/shift.py @@ -61,6 +61,7 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): """ @@ -78,8 +79,15 @@ def __init__( :param p: :param p_mode: :param sample_rate: + :param target_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) self.min_shift = min_shift self.max_shift = max_shift self.shift_unit = shift_unit @@ -90,17 +98,24 @@ def __init__( raise ValueError('shift_unit must be "samples", "fraction" or "seconds"') def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): if self.shift_unit == "samples": min_shift_in_samples = self.min_shift max_shift_in_samples = self.max_shift + elif self.shift_unit == "fraction": min_shift_in_samples = int(round(self.min_shift * selected_samples.shape[-1])) max_shift_in_samples = int(round(self.max_shift * selected_samples.shape[-1])) + elif self.shift_unit == "seconds": min_shift_in_samples = int(round(self.min_shift * sample_rate)) max_shift_in_samples = int(round(self.max_shift * sample_rate)) + else: raise ValueError("Invalid shift_unit") @@ -122,6 +137,7 @@ def randomize_parameters( dtype=torch.int32, device=selected_samples.device, ) + else: self.transform_parameters["num_samples_to_shift"] = torch.randint( low=min_shift_in_samples, @@ -131,12 +147,38 @@ def randomize_parameters( device=selected_samples.device, ) - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): - r = self.transform_parameters["num_samples_to_shift"] + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + + num_samples_to_shift = self.transform_parameters["num_samples_to_shift"] + # Select fastest implementation based on device shift = shift_gpu if selected_samples.device.type == "cuda" else shift_cpu - return shift(selected_samples, r, self.rollover) + shifted_samples = shift(selected_samples, num_samples_to_shift, self.rollover) + + if targets is None: + shifted_targets = targets + else: + # FIXME corner case where target_rate is missing + # FIXME corner case where target is not correlated with the input length + num_frames_to_shift = int( + round(target_rate * num_samples_to_shift / sample_rate) + ) + shifted_targets = shift( + targets.transpose(-2, -1), num_frames_to_shift, self.rollover + ).transpose(-2, -1) + + return shifted_samples, shifted_targets def is_sample_rate_required(self) -> bool: # Sample rate is required only if shift_unit is "seconds" return self.shift_unit == "seconds" + + def is_target_rate_required(self) -> bool: + # FIXME should be True only when targets is passed to apply_transform + return self.requires_target_rate diff --git a/torch_audiomentations/augmentations/shuffle_channels.py b/torch_audiomentations/augmentations/shuffle_channels.py index 7f06eea5..b016d2a9 100644 --- a/torch_audiomentations/augmentations/shuffle_channels.py +++ b/torch_audiomentations/augmentations/shuffle_channels.py @@ -24,11 +24,22 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: 
typing.Optional[int] = None, ): - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, ): batch_size = selected_samples.shape[0] num_channels = selected_samples.shape[1] @@ -40,15 +51,26 @@ def randomize_parameters( permutations[i] = torch.randperm(num_channels, device=selected_samples.device) self.transform_parameters["permutations"] = permutations - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + if selected_samples.shape[1] == 1: warnings.warn( "Mono audio was passed to ShuffleChannels - there are no channels to shuffle." " The input will be returned unchanged." ) - return selected_samples + return selected_samples, targets + for i in range(selected_samples.size(0)): selected_samples[i] = selected_samples[ i, self.transform_parameters["permutations"][i] ] - return selected_samples + if targets is not None: + targets[i] = targets[i, self.transform_parameters["permutations"][i]] + + return selected_samples, targets diff --git a/torch_audiomentations/augmentations/time_inversion.py b/torch_audiomentations/augmentations/time_inversion.py index 53191d71..c4c16542 100644 --- a/torch_audiomentations/augmentations/time_inversion.py +++ b/torch_audiomentations/augmentations/time_inversion.py @@ -21,6 +21,7 @@ def __init__( p: float = 0.5, p_mode: str = None, sample_rate: int = None, + target_rate: int = None, ): """ :param mode: @@ -28,12 +29,32 @@ def __init__( :param p_mode: :param sample_rate: """ - super().__init__(mode, p, p_mode, sample_rate) + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) + + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): - def apply_transform(self, selected_samples: torch.Tensor, sample_rate: int = None): # torch.flip() is supposed to be slower than np.flip() # An alternative is to use advanced indexing: https://github.com/pytorch/pytorch/issues/16424 # reverse_index = torch.arange(selected_samples.size(-1) - 1, -1, -1).to(selected_samples.device) # transformed_samples = selected_samples[..., reverse_index] # return transformed_samples - return torch.flip(selected_samples, dims=(-1,)) + + flipped_samples = torch.flip(selected_samples, dims=(-1,)) + if targets is None: + flipped_targets = targets + else: + flipped_targets = torch.flip(targets, dims=(-2,)) + + return flipped_samples, flipped_targets diff --git a/torch_audiomentations/core/composition.py b/torch_audiomentations/core/composition.py index e37bd07f..e360d6c9 100644 --- a/torch_audiomentations/core/composition.py +++ b/torch_audiomentations/core/composition.py @@ -63,7 +63,13 @@ def supported_modes(self) -> set: class Compose(BaseCompose): - def forward(self, samples, sample_rate: typing.Optional[int] = None): + def forward( + self, + samples, + sample_rate: typing.Optional[int] = None, + targets=None, + target_rate: typing.Optional[int] = None, + ): if random.random() < self.p: transform_indexes = 
list(range(len(self.transforms))) if self.shuffle: @@ -71,8 +77,14 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): for i in transform_indexes: tfm = self.transforms[i] if isinstance(tfm, BaseWaveformTransform): - samples = self.transforms[i](samples, sample_rate) + if targets is None: + samples = self.transforms[i](samples, sample_rate) + else: + samples, targets = self.transforms[i]( + samples, sample_rate, targets=targets, target_rate=target_rate + ) else: + # FIXME: add support for targets? samples = self.transforms[i](samples) return samples @@ -131,7 +143,13 @@ def randomize_parameters(self): random.sample(self.all_transforms_indexes, num_transforms_to_apply) ) - def forward(self, samples, sample_rate: typing.Optional[int] = None): + def forward( + self, + samples, + sample_rate: typing.Optional[int] = None, + targets=None, + target_rate: typing.Optional[int] = None, + ): if random.random() < self.p: if not self.are_parameters_frozen: @@ -140,8 +158,14 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): for i in self.transform_indexes: tfm = self.transforms[i] if isinstance(tfm, BaseWaveformTransform): - samples = self.transforms[i](samples, sample_rate) + if targets is None: + samples = self.transforms[i](samples, sample_rate) + else: + samples, targets = self.transforms[i]( + samples, sample_rate, targets=targets, target_rate=target_rate + ) else: + # FIXME: add support for targets? samples = self.transforms[i](samples) return samples diff --git a/torch_audiomentations/core/transforms_interface.py b/torch_audiomentations/core/transforms_interface.py index 7cc90ca8..3908a88e 100644 --- a/torch_audiomentations/core/transforms_interface.py +++ b/torch_audiomentations/core/transforms_interface.py @@ -20,9 +20,12 @@ class ModeNotSupportedException(Exception): class BaseWaveformTransform(torch.nn.Module): + supports_multichannel = True supported_modes = {"per_batch", "per_example", "per_channel"} requires_sample_rate = True + requires_targets = False + requires_target_rate = False def __init__( self, @@ -30,6 +33,7 @@ def __init__( p: float = 0.5, p_mode: typing.Optional[str] = None, sample_rate: typing.Optional[int] = None, + target_rate: typing.Optional[int] = None, ): """ @@ -49,6 +53,8 @@ def __init__( with different parameters. Default value: Same as mode. :param sample_rate: sample_rate can be set either here or when calling the transform. + :param target_rate: target_rate can be set either here or when + calling the transform. """ super().__init__() assert 0.0 <= p <= 1.0 @@ -58,6 +64,7 @@ def __init__( if self.p_mode is None: self.p_mode = self.mode self.sample_rate = sample_rate + self.target_rate = target_rate # Check validity of mode/p_mode combination if self.mode not in self.supported_modes: @@ -87,9 +94,19 @@ def p(self, p): # Update the Bernoulli distribution accordingly self.bernoulli_distribution = Bernoulli(self._p) - def forward(self, samples, sample_rate: typing.Optional[int] = None): + def forward( + self, + samples, + sample_rate: typing.Optional[int] = None, + targets=None, + target_rate: typing.Optional[int] = None, + ): + if not self.training: - return samples + if targets is None: + return samples + else: + return samples, targets if len(samples) == 0: warnings.warn( @@ -104,8 +121,10 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): " audio is mono, you can use a shape like [batch_size, 1, num_samples]." 
) + batch_size, num_channels, num_samples = samples.shape + if is_multichannel(samples): - if samples.shape[1] > samples.shape[2]: + if num_channels > num_samples: warnings.warn( "Multichannel audio must have channels first, not channels last. In" " other words, the shape must be (batch size, channels, samples), not" @@ -122,15 +141,49 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): if sample_rate is None and self.is_sample_rate_required(): raise RuntimeError("sample_rate is required") + if self.is_targets_required(): + + if targets is None: + raise RuntimeError("targets is required") + + if len(targets.shape) != 4: + raise RuntimeError( + "torch-audiomentations expects target tensors to be four-dimensional, with" + " dimension ordering like [batch_size, num_channels, num_frames, num_classes]." + " If your target is binary, you can use a shape like [batch_size, num_channels, num_frames, 1]." + " If your target is for the whole channel, you can use a shape like [batch_size, num_channels, 1, num_classes]." + ) + + batch_size_, num_channels_, num_frames, num_classes = targets.shape + + if batch_size_ != batch_size: + raise RuntimeError( + f"samples ({batch_size}) and target ({batch_size_}) batch sizes must be equal." + ) + if num_channels != num_channels_: + raise RuntimeError( + f"samples ({num_channels}) and target ({num_channels_}) number of channels must be equal." + ) + + target_rate = target_rate or self.target_rate + if target_rate is None and self.is_target_rate_required(): + # IDEA: automatically estimate target_rate based on samples, sample_rate, and targets + raise RuntimeError("target_rate is required") + if not self.are_parameters_frozen: + if self.p_mode == "per_example": - p_sample_size = samples.shape[0] + p_sample_size = batch_size + elif self.p_mode == "per_channel": - p_sample_size = samples.shape[0] * samples.shape[1] + p_sample_size = batch_size * num_channels + elif self.p_mode == "per_batch": p_sample_size = 1 + else: raise Exception("Invalid mode") + self.transform_parameters = { "should_apply": self.bernoulli_distribution.sample( sample_shape=(p_sample_size,) @@ -138,103 +191,255 @@ def forward(self, samples, sample_rate: typing.Optional[int] = None): } if self.transform_parameters["should_apply"].any(): + cloned_samples = samples.clone() + if targets is None: + cloned_targets = None + selected_targets = None + else: + cloned_targets = targets.clone() + if self.p_mode == "per_channel": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - batch_size * num_channels, 1, cloned_samples.shape[2] + batch_size * num_channels, 1, num_samples ) selected_samples = cloned_samples[ self.transform_parameters["should_apply"] ] + if targets is not None: + cloned_targets = cloned_targets.reshape( + batch_size * num_channels, 1, num_frames, num_classes + ) + selected_targets = cloned_targets[ + self.transform_parameters["should_apply"] + ] + if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + self.randomize_parameters( + selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = self.apply_transform(selected_samples, sample_rate) - + ] = perturbed_samples cloned_samples = 
cloned_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + batch_size, num_channels, num_samples ) - return cloned_samples + if targets is None: + return cloned_samples + + else: + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed_targets + cloned_targets = cloned_targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + return cloned_samples, cloned_targets elif self.p_mode == "per_example": + selected_samples = cloned_samples[ self.transform_parameters["should_apply"] ] + if targets is not None: + selected_targets = cloned_targets[ + self.transform_parameters["should_apply"] + ] + if self.mode == "per_example": + if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + self.randomize_parameters( + selected_samples, + sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = self.apply_transform(selected_samples, sample_rate) - return cloned_samples + ] = perturbed_samples + + if targets is None: + return cloned_samples + + else: + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed_targets + return cloned_samples, cloned_targets + elif self.mode == "per_channel": - batch_size = selected_samples.shape[0] - num_channels = selected_samples.shape[1] - selected_samples = selected_samples.reshape( - batch_size * num_channels, 1, selected_samples.shape[2] - ) - if not self.are_parameters_frozen: - self.randomize_parameters(selected_samples, sample_rate) + b, c, s = selected_samples.shape - perturbed_samples = self.apply_transform( - selected_samples, sample_rate - ) - perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, selected_samples.shape[2] + selected_samples = selected_samples.reshape(b * c, 1, s) + + if targets is not None: + selected_targets = selected_targets.reshape( + b * c, 1, num_frames, num_classes + ) + + if not self.are_parameters_frozen: + self.randomize_parameters( + selected_samples, + sample_rate, + targets=selected_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + selected_samples, + sample_rate=sample_rate, + targets=selected_targets, + target_rate=target_rate, ) + perturbed_samples = perturbed_samples.reshape(b, c, s) cloned_samples[ self.transform_parameters["should_apply"] ] = perturbed_samples - return cloned_samples + + if targets is None: + return cloned_samples + + else: + perturbed_targets = perturbed_targets.reshape( + b, c, num_frames, num_classes + ) + cloned_targets[ + self.transform_parameters["should_apply"] + ] = perturbed_targets + return cloned_samples, cloned_targets + else: raise Exception("Invalid mode/p_mode combination") + elif self.p_mode == "per_batch": + if self.mode == "per_batch": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - 1, batch_size * num_channels, cloned_samples.shape[2] + 1, batch_size * num_channels, num_samples ) - if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) + if targets is not None: + cloned_targets = cloned_targets.reshape( + 1, batch_size * num_channels, num_frames, num_classes + ) - perturbed_samples = self.apply_transform(cloned_samples, sample_rate) + if not 
self.are_parameters_frozen: + self.randomize_parameters( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + batch_size, num_channels, num_samples ) - return perturbed_samples + + if targets is None: + return perturbed_samples + + else: + perturbed_targets = perturbed_targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + return perturbed_samples, perturbed_targets + elif self.mode == "per_example": + if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) - return self.apply_transform(cloned_samples, sample_rate) + self.randomize_parameters( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + if targets is None: + return perturbed_samples + + else: + return perturbed_samples, perturbed_targets + elif self.mode == "per_channel": - batch_size = cloned_samples.shape[0] - num_channels = cloned_samples.shape[1] + cloned_samples = cloned_samples.reshape( - batch_size * num_channels, 1, cloned_samples.shape[2] + batch_size * num_channels, 1, num_samples ) - if not self.are_parameters_frozen: - self.randomize_parameters(cloned_samples, sample_rate) + if targets is not None: + cloned_targets = cloned_targets.reshape( + batch_size * num_channels, 1, num_frames, num_classes + ) - perturbed_samples = self.apply_transform(cloned_samples, sample_rate) + if not self.are_parameters_frozen: + self.randomize_parameters( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) + + perturbed_samples, perturbed_targets = self.apply_transform( + cloned_samples, + sample_rate, + targets=cloned_targets, + target_rate=target_rate, + ) perturbed_samples = perturbed_samples.reshape( - batch_size, num_channels, cloned_samples.shape[2] + batch_size, num_channels, num_samples ) - return perturbed_samples + + if targets is None: + return perturbed_samples + + else: + perturbed_targets = perturbed_targets.reshape( + batch_size, num_channels, num_frames, num_classes + ) + + return perturbed_samples, perturbed_targets else: raise Exception("Invalid mode") else: @@ -248,11 +453,22 @@ def _forward_unimplemented(self, *inputs) -> None: pass def randomize_parameters( - self, selected_samples, sample_rate: typing.Optional[int] = None + self, + selected_samples, + sample_rate: typing.Optional[int] = None, + targets=None, + target_rate: typing.Optional[int] = None, ): pass - def apply_transform(self, selected_samples, sample_rate: typing.Optional[int] = None): + def apply_transform( + self, + selected_samples, + sample_rate: typing.Optional[int] = None, + targets=None, + target_rate: typing.Optional[int] = None, + ): + raise NotImplementedError() def serialize_parameters(self): @@ -276,3 +492,9 @@ def unfreeze_parameters(self): def is_sample_rate_required(self) -> bool: return self.requires_sample_rate + + def is_targets_required(self) -> bool: + return self.requires_targets + + def is_target_rate_required(self) -> bool: + return self.requires_target_rate From 84ea872de9d01bf593602e9dd7d640157be8f5eb Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 08:57:44 +0100 Subject: [PATCH 02/15] fix: fix missing targets handling --- torch_audiomentations/core/transforms_interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_audiomentations/core/transforms_interface.py b/torch_audiomentations/core/transforms_interface.py index 3908a88e..158a9d5a 100644 --- a/torch_audiomentations/core/transforms_interface.py +++ b/torch_audiomentations/core/transforms_interface.py @@ -141,10 +141,10 @@ def forward( if sample_rate is None and self.is_sample_rate_required(): raise RuntimeError("sample_rate is required") - if self.is_targets_required(): + if targets is None and self.is_targets_required(): + raise RuntimeError("targets is required") - if targets is None: - raise RuntimeError("targets is required") + if targets is not None: if len(targets.shape) != 4: raise RuntimeError( From e4f68726236ad34ca627052adb3bfb45a7436b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 23 Mar 2022 14:10:38 +0100 Subject: [PATCH 03/15] feat: add support for batches in Audio.rms_normalize --- torch_audiomentations/utils/io.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_audiomentations/utils/io.py b/torch_audiomentations/utils/io.py index ad829666..dbe2a6ba 100644 --- a/torch_audiomentations/utils/io.py +++ b/torch_audiomentations/utils/io.py @@ -84,16 +84,16 @@ def rms_normalize(samples: Tensor) -> Tensor: Parameters ---------- - samples : (channel, time) Tensor - Single or multichannel samples + samples : (..., time) Tensor + Single (or multichannel) samples or batch of samples Returns ------- - samples: (channel, time) Tensor + samples: (..., time) Tensor Power-normalized samples """ - rms = samples.square().mean(dim=1).sqrt() - return (samples.t() / (rms + 1e-8)).t() + rms = samples.square().mean(dim=-1, keepdim=True).sqrt() + return samples / (rms + 1e-8) @staticmethod def get_audio_metadata(file_path) -> tuple: From 2ff22c8f09c9eac138e70dab457e348028df009d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 09:25:47 +0100 Subject: [PATCH 04/15] wip: add Mix transform --- torch_audiomentations/augmentations/mix.py | 110 +++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 torch_audiomentations/augmentations/mix.py diff --git a/torch_audiomentations/augmentations/mix.py b/torch_audiomentations/augmentations/mix.py new file mode 100644 index 00000000..733a6c12 --- /dev/null +++ b/torch_audiomentations/augmentations/mix.py @@ -0,0 +1,110 @@ +from typing import Optional +import torch + +from ..core.transforms_interface import BaseWaveformTransform +from ..utils.dsp import calculate_rms +from ..utils.io import Audio + + +class Mix(BaseWaveformTransform): + """ + Create a new sample by mixing it with another random sample from the same batch + + Signal-to-noise ratio (where "noise" is the second random sample) is selected + randomly between `min_snr_in_db` and `max_snr_in_db`. + + `mix_target` controls how resulting targets are generated. 
It can be one of + "original" (targets are those of the original sample) or "union" (targets are the + union of original and overlapping targets) + + """ + + supports_multichannel = True + supported_modes = {"per_example", "per_channel"} + requires_sample_rate = False + requires_targets = False + requires_target_rate = False + + def __init__( + self, + min_snr_in_db: float = 0.0, + max_snr_in_db: float = 5.0, + mix_target: str = "union", + mode: str = "per_example", + p: float = 0.5, + p_mode: str = None, + sample_rate: int = None, + target_rate: int = None, + ): + super().__init__( + mode=mode, + p=p, + p_mode=p_mode, + sample_rate=sample_rate, + target_rate=target_rate, + ) + self.min_snr_in_db = min_snr_in_db + self.max_snr_in_db = max_snr_in_db + if self.min_snr_in_db > self.max_snr_in_db: + raise ValueError("min_snr_in_db must not be greater than max_snr_in_db") + + # TODO: make it configuration via a "string" + self.mix_target = mix_target + self._mix_target = lambda target, background_target, snr: torch.maximize( + target, background_target + ) + + def randomize_parameters( + self, + selected_samples, + sample_rate: Optional[int] = None, + targets=None, + target_rate: Optional[int] = None, + ): + + batch_size, num_channels, num_samples = selected_samples.shape + snr_distribution = torch.distributions.Uniform( + low=torch.tensor( + self.min_snr_in_db, dtype=torch.float32, device=selected_samples.device, + ), + high=torch.tensor( + self.max_snr_in_db, dtype=torch.float32, device=selected_samples.device, + ), + validate_args=True, + ) + + # randomize SNRs + self.transform_parameters["snr_in_db"] = snr_distribution.sample( + sample_shape=(batch_size,) + ) + + # randomize index of second sample + self.transform_parameters["sample_idx"] = torch.randint( + 0, batch_size, (batch_size,), device=selected_samples.device, + ) + + def apply_transform( + self, + selected_samples: torch.Tensor, + sample_rate: int = None, + targets: torch.Tensor = None, + target_rate: int = None, + ): + + snr = self.transform_parameters["snr_in_db"] + idx = self.transform_parameters["sample_idx"] + + background_samples = Audio.rms_normalize(selected_samples[idx]) + background_rms = calculate_rms(selected_samples) / ( + 10 ** (snr.unsqueeze(dim=-1) / 20) + ) + + perturbed_samples = ( + selected_samples + background_rms.unsqueeze(-1) * background_samples + ) + if targets is None: + return perturbed_samples + + background_targets = targets[idx] + perturbed_targets = self._mix_target(targets, background_targets, snr) + return perturbed_samples, perturbed_targets From 94a673696fef8935212ef4ca4590208573a0812d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 09:39:24 +0100 Subject: [PATCH 05/15] fix: honor "mix_target" option --- torch_audiomentations/augmentations/mix.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torch_audiomentations/augmentations/mix.py b/torch_audiomentations/augmentations/mix.py index 733a6c12..84fecc6a 100644 --- a/torch_audiomentations/augmentations/mix.py +++ b/torch_audiomentations/augmentations/mix.py @@ -48,11 +48,17 @@ def __init__( if self.min_snr_in_db > self.max_snr_in_db: raise ValueError("min_snr_in_db must not be greater than max_snr_in_db") - # TODO: make it configuration via a "string" self.mix_target = mix_target - self._mix_target = lambda target, background_target, snr: torch.maximize( - target, background_target - ) + if mix_target == "original": + self._mix_target = lambda target, background_target, snr: 
target + + elif mix_target == "union": + self._mix_target = lambda target, background_target, snr: torch.maximize( + target, background_target + ) + + else: + raise ValueError("mix_target must be one of 'original' or 'union'.") def randomize_parameters( self, From efe080650ec0c0dc013cda4a268df0b3bf48f1ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 10:26:06 +0100 Subject: [PATCH 06/15] fix: fix typo --- torch_audiomentations/augmentations/mix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_audiomentations/augmentations/mix.py b/torch_audiomentations/augmentations/mix.py index 84fecc6a..7bdf8681 100644 --- a/torch_audiomentations/augmentations/mix.py +++ b/torch_audiomentations/augmentations/mix.py @@ -53,7 +53,7 @@ def __init__( self._mix_target = lambda target, background_target, snr: target elif mix_target == "union": - self._mix_target = lambda target, background_target, snr: torch.maximize( + self._mix_target = lambda target, background_target, snr: torch.maximum( target, background_target ) From 7b7979a94fba428f6e1cb3448186076734aa20eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 10:26:16 +0100 Subject: [PATCH 07/15] fix: add Mix to the mix --- torch_audiomentations/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_audiomentations/__init__.py b/torch_audiomentations/__init__.py index aab77651..75c71c1f 100644 --- a/torch_audiomentations/__init__.py +++ b/torch_audiomentations/__init__.py @@ -6,6 +6,7 @@ from .augmentations.high_pass_filter import HighPassFilter from .augmentations.impulse_response import ApplyImpulseResponse from .augmentations.low_pass_filter import LowPassFilter +from .augmentations.mix import Mix from .augmentations.peak_normalization import PeakNormalization from .augmentations.pitch_shift import PitchShift from .augmentations.polarity_inversion import PolarityInversion From b60b07d8721264802baa4355ac128eeca7d35ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 10:29:14 +0100 Subject: [PATCH 08/15] fix: fix composition support for target --- torch_audiomentations/core/composition.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torch_audiomentations/core/composition.py b/torch_audiomentations/core/composition.py index e360d6c9..9aecb90c 100644 --- a/torch_audiomentations/core/composition.py +++ b/torch_audiomentations/core/composition.py @@ -86,7 +86,11 @@ def forward( else: # FIXME: add support for targets? samples = self.transforms[i](samples) - return samples + + if targets is None: + return samples + else: + return samples, targets class SomeOf(BaseCompose): @@ -167,7 +171,11 @@ def forward( else: # FIXME: add support for targets? 
samples = self.transforms[i](samples) - return samples + + if targets is None: + return samples + else: + return samples, targets class OneOf(SomeOf): From d05877349766ac93e44bdf9e9efe41d80ae8087b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 24 Mar 2022 19:05:34 +0100 Subject: [PATCH 09/15] fix: fix two corner cases --- torch_audiomentations/core/transforms_interface.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch_audiomentations/core/transforms_interface.py b/torch_audiomentations/core/transforms_interface.py index 158a9d5a..c3d4446e 100644 --- a/torch_audiomentations/core/transforms_interface.py +++ b/torch_audiomentations/core/transforms_interface.py @@ -112,7 +112,10 @@ def forward( warnings.warn( "An empty samples tensor was passed to {}".format(self.__class__.__name__) ) - return samples + if targets is None: + return samples + else: + return samples, targets if len(samples.shape) != 3: raise RuntimeError( @@ -445,7 +448,10 @@ def forward( else: raise Exception("Invalid p_mode {}".format(self.p_mode)) - return samples + if targets is None: + return samples + else: + return samples, targets def _forward_unimplemented(self, *inputs) -> None: # Avoid IDE error message like "Class ... must implement all abstract methods" From 30e723b5a264d1794ccc5c51cc9dd12ed8576eb9 Mon Sep 17 00:00:00 2001 From: iver56 Date: Fri, 25 Mar 2022 13:19:21 +0100 Subject: [PATCH 10/15] Add ObjectDict class for results --- torch_audiomentations/utils/object_dict.py | 39 ++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 torch_audiomentations/utils/object_dict.py diff --git a/torch_audiomentations/utils/object_dict.py b/torch_audiomentations/utils/object_dict.py new file mode 100644 index 00000000..3af0754f --- /dev/null +++ b/torch_audiomentations/utils/object_dict.py @@ -0,0 +1,39 @@ +# Inspired by tornado +# https://www.tornadoweb.org/en/stable/_modules/tornado/util.html#ObjectDict + +try: + import typing + from typing import cast + + _ObjectDictBase = typing.Dict[str, typing.Any] +except ImportError: + _ObjectDictBase = dict + + +class ObjectDict(_ObjectDictBase): + """ + Make a dictionary behave like an object, with attribute-style access. 
+ + Here are some examples of how it can be used: + + o = ObjectDict(my_dict) + # or like this: + o = ObjectDict(samples=samples, sample_rate=sample_rate) + + # Attribute-style access + samples = o.samples + + # Dict-style access + samples = o["samples"] + """ + + def __getattr__(self, name): + # type: (str) -> Any + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name, value): + # type: (str, Any) -> None + self[name] = value From 4ec0aaf00db5bf70facf23dd7cc2b61c8585fb5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 29 Mar 2022 14:30:32 +0200 Subject: [PATCH 11/15] BREAKING: switch to ObjectDict output --- tests/test_background_noise.py | 28 +- tests/test_band_pass_filter.py | 2 +- tests/test_band_stop_filter.py | 2 +- tests/test_colored_noise.py | 12 +- tests/test_compose.py | 24 +- tests/test_differentiable.py | 2 +- tests/test_gain.py | 28 +- tests/test_high_pass_filter.py | 4 +- tests/test_impulse_response.py | 19 +- tests/test_low_pass_filter.py | 2 +- tests/test_peak_normalization.py | 22 +- tests/test_pitch_shift.py | 6 +- tests/test_polarity_inversion.py | 10 +- tests/test_shift.py | 14 +- tests/test_shuffle_channels.py | 12 +- tests/test_someof.py | 16 +- tests/test_time_inversion.py | 4 +- .../augmentations/background_noise.py | 55 +-- .../augmentations/band_pass_filter.py | 56 +-- .../augmentations/band_stop_filter.py | 26 +- .../augmentations/colored_noise.py | 52 +-- torch_audiomentations/augmentations/gain.py | 49 ++- .../augmentations/high_pass_filter.py | 28 +- .../augmentations/impulse_response.py | 52 ++- .../augmentations/low_pass_filter.py | 54 +-- torch_audiomentations/augmentations/mix.py | 58 ++-- .../augmentations/peak_normalization.py | 44 ++- .../augmentations/pitch_shift.py | 61 ++-- .../augmentations/polarity_inversion.py | 35 +- torch_audiomentations/augmentations/shift.py | 66 ++-- .../augmentations/shuffle_channels.py | 59 ++-- .../augmentations/time_inversion.py | 27 +- torch_audiomentations/core/composition.py | 89 ++--- .../core/transforms_interface.py | 318 ++++++++++-------- 34 files changed, 764 insertions(+), 572 deletions(-) diff --git a/tests/test_background_noise.py b/tests/test_background_noise.py index 0a87d50b..b10a904f 100644 --- a/tests/test_background_noise.py +++ b/tests/test_background_noise.py @@ -21,7 +21,7 @@ class TestAddBackgroundNoise(unittest.TestCase): def setUp(self): self.sample_rate = 16000 self.batch_size = 16 - self.empty_input_audio = torch.empty(0) + self.empty_input_audio = torch.empty(0, 1, 16000) # TODO: use utils.io.Audio self.input_audio = ( torch.from_numpy( @@ -47,7 +47,7 @@ def setUp(self): def test_background_noise_no_guarantee_with_single_tensor(self): mixed_input = self.bg_noise_transform_no_guarantee( self.input_audio, self.sample_rate - ) + ).samples self.assertTrue(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) @@ -55,7 +55,7 @@ def test_background_noise_no_guarantee_with_empty_tensor(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.bg_noise_transform_no_guarantee( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -69,7 +69,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.bg_noise_transform_guaranteed( self.empty_input_audio, 
self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -81,7 +81,7 @@ def test_background_noise_guaranteed_with_zero_length_samples(self): def test_background_noise_guaranteed_with_single_tensor(self): mixed_input = self.bg_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -90,7 +90,7 @@ def test_background_noise_guaranteed_with_batched_tensor(self): random.seed(42) mixed_inputs = self.bg_noise_transform_guaranteed( self.input_audios, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_inputs, self.input_audios)) self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0)) self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1)) @@ -98,7 +98,7 @@ def test_background_noise_guaranteed_with_batched_tensor(self): def test_background_short_noise_guaranteed_with_batched_tensor(self): mixed_input = self.bg_short_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -108,7 +108,7 @@ def test_background_short_noise_guaranteed_with_batched_cuda_tensor(self): input_audio_cuda = self.input_audio.cuda() mixed_input = self.bg_short_noise_transform_guaranteed( input_audio_cuda, self.sample_rate - ) + ).samples assert not torch.equal(mixed_input, input_audio_cuda) assert mixed_input.shape == input_audio_cuda.shape assert mixed_input.dtype == input_audio_cuda.dtype @@ -120,7 +120,7 @@ def test_varying_snr_within_batch(self): augment = AddBackgroundNoise( self.bg_path, min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0 ) - augmented_audios = augment(self.input_audios, self.sample_rate) + augmented_audios = augment(self.input_audios, self.sample_rate).samples self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape)) self.assertFalse(torch.equal(augmented_audios, self.input_audios)) @@ -150,7 +150,7 @@ def test_min_equals_max(self): augment = AddBackgroundNoise( self.bg_path, min_snr_in_db=desired_snr, max_snr_in_db=desired_snr, p=1.0 ) - augmented_audios = augment(self.input_audios, self.sample_rate) + augmented_audios = augment(self.input_audios, self.sample_rate).samples self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape)) self.assertFalse(torch.equal(augmented_audios, self.input_audios)) @@ -171,11 +171,9 @@ def test_compatibility_of_resampled_length(self): input_sample_rate = random.randint(1000, 5000) bg_sample_rate = random.randint(1000, 5000) - noise = np.random.uniform( - low=-0.2, - high=0.2, - size=(bg_length,), - ).astype(np.float32) + noise = np.random.uniform(low=-0.2, high=0.2, size=(bg_length,),).astype( + np.float32 + ) tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) try: os.makedirs(tmp_dir) diff --git a/tests/test_band_pass_filter.py b/tests/test_band_pass_filter.py index 98698c19..03dfe59b 100644 --- a/tests/test_band_pass_filter.py +++ b/tests/test_band_pass_filter.py @@ -22,6 +22,6 @@ def test_band_pass_filter(self): for _ in range(20): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() 
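+            # Sanity check of the new ObjectDict return type: attribute-style
+            # and dict-style access point at the same tensor (sketch of the
+            # new calling convention).
+            output = augment(samples=torch.from_numpy(samples), sample_rate=sample_rate)
+            self.assertTrue(torch.equal(output.samples, output["samples"]))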
self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_band_stop_filter.py b/tests/test_band_stop_filter.py index 2a56716d..5eb32a99 100644 --- a/tests/test_band_stop_filter.py +++ b/tests/test_band_stop_filter.py @@ -21,6 +21,6 @@ def test_band_reject_filter(self): augment = BandStopFilter(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_colored_noise.py b/tests/test_colored_noise.py index cf8ab35c..4d7cfbf5 100644 --- a/tests/test_colored_noise.py +++ b/tests/test_colored_noise.py @@ -13,7 +13,7 @@ def setUp(self): self.sample_rate = 16000 self.audio = Audio(sample_rate=self.sample_rate) self.batch_size = 16 - self.empty_input_audio = torch.empty(0) + self.empty_input_audio = torch.empty(0, 1, 16000) self.input_audio = self.audio( TEST_FIXTURES_DIR / "acoustic_guitar_0.wav" @@ -26,7 +26,7 @@ def setUp(self): def test_colored_noise_no_guarantee_with_single_tensor(self): mixed_input = self.cl_noise_transform_no_guarantee( self.input_audio, self.sample_rate - ) + ).samples self.assertTrue(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) @@ -34,7 +34,7 @@ def test_background_noise_no_guarantee_with_empty_tensor(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.cl_noise_transform_no_guarantee( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -48,7 +48,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self): with self.assertWarns(UserWarning) as warning_context_manager: mixed_input = self.cl_noise_transform_guaranteed( self.empty_input_audio, self.sample_rate - ) + ).samples self.assertIn( "An empty samples tensor was passed", str(warning_context_manager.warning) @@ -60,7 +60,7 @@ def test_colored_noise_guaranteed_with_zero_length_samples(self): def test_colored_noise_guaranteed_with_single_tensor(self): mixed_input = self.cl_noise_transform_guaranteed( self.input_audio, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_input, self.input_audio)) self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) @@ -69,7 +69,7 @@ def test_colored_noise_guaranteed_with_batched_tensor(self): random.seed(42) mixed_inputs = self.cl_noise_transform_guaranteed( self.input_audios, self.sample_rate - ) + ).samples self.assertFalse(torch.equal(mixed_inputs, self.input_audios)) self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0)) self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1)) diff --git a/tests/test_compose.py b/tests/test_compose.py index 0a2d3e86..3750978b 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -24,7 +24,7 @@ def test_compose(self): ) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = -convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -41,7 +41,7 @@ def test_compose_with_torchaudio_transform(self): augment = Compose([Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0)]) processed_samples = augment( samples=torch.from_numpy(samples), 
sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = -convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -64,7 +64,7 @@ def test_compose_with_p_zero(self): ) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_array_equal(samples, processed_samples) def test_freeze_and_unfreeze_parameters(self): @@ -82,17 +82,17 @@ def test_freeze_and_unfreeze_parameters(self): processed_samples1 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() augment.freeze_parameters() processed_samples2 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_array_equal(processed_samples1, processed_samples2) augment.unfreeze_parameters() processed_samples3 = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertNotEqual(processed_samples1[0, 0, 0], processed_samples3[0, 0, 0]) def test_shuffle(self): @@ -112,7 +112,7 @@ def test_shuffle(self): for i in range(100): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() # Either PeakNormalization or Gain was applied last if processed_samples[0, 0, 0] < 0.2: @@ -126,14 +126,8 @@ def test_shuffle(self): self.assertGreater(num_gain_last, 10) def test_supported_modes_property(self): - augment = Compose( - transforms=[ - PeakNormalization(p=1.0), - ], - ) + augment = Compose(transforms=[PeakNormalization(p=1.0),],) assert augment.supported_modes == {"per_batch", "per_example", "per_channel"} - augment = Compose( - transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)], - ) + augment = Compose(transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],) assert augment.supported_modes == {"per_example"} diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index d0b96eef..29c7ed09 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -65,7 +65,7 @@ def test_transform_is_differentiable(augment): optim = SGD([samples], lr=1.0) for i in range(10): optim.zero_grad() - transformed = augment(samples=samples, sample_rate=sample_rate) + transformed = augment(samples=samples, sample_rate=sample_rate).samples # Compute mean absolute error loss = torch.mean(torch.abs(samples - transformed)) loss.backward() diff --git a/tests/test_gain.py b/tests/test_gain.py index 6510a4a6..a0695a42 100644 --- a/tests/test_gain.py +++ b/tests/test_gain.py @@ -20,7 +20,7 @@ def test_gain(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() expected_factor = convert_decibels_to_amplitude_ratio(-6) assert_almost_equal( processed_samples, @@ -43,7 +43,7 @@ def test_gain_per_channel(self): ) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() num_unprocessed_channels = 0 num_processed_channels = 0 @@ -140,7 +140,7 @@ def test_gain_per_channel_with_p_mode_per_batch(self): for i in range(100): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() if np.allclose(processed_samples, samples_batch): continue @@ -235,7 +235,7 @@ def test_gain_per_channel_with_p_mode_per_example(self): ) processed_samples = augment( 
samples=torch.from_numpy(samples_batch), sample_rate=None - ).numpy() + ).samples.numpy() num_unprocessed_examples = 0 num_processed_examples = 0 @@ -331,7 +331,7 @@ def test_gain_per_batch(self): for i in range(1000): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() estimated_gain_factors = processed_samples / samples_batch self.assertAlmostEqual( @@ -354,7 +354,7 @@ def test_eval(self): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() np.testing.assert_array_equal(samples, processed_samples) @@ -366,7 +366,7 @@ def test_variability_within_batch(self): augment = Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) num_unprocessed_examples = 0 @@ -406,7 +406,7 @@ def test_variability_within_batch_with_p_mode_per_batch(self): for _ in range(100): processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) if np.allclose(processed_samples, samples_batch): @@ -456,7 +456,9 @@ def test_reset_distribution(self): # Change the parameters after init augment.min_gain_in_db = -18 augment.max_gain_in_db = 3 - processed_samples = augment(samples=torch.from_numpy(samples_batch)).numpy() + processed_samples = augment( + samples=torch.from_numpy(samples_batch) + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) actual_gains_in_db = [] @@ -490,7 +492,7 @@ def test_cuda_reset_distribution(self): augment( samples=torch.from_numpy(samples_batch).cuda(), sample_rate=sample_rate ) - .cpu() + .samples.cpu() .numpy() ) self.assertEqual(processed_samples.dtype, np.float32) @@ -538,7 +540,7 @@ def test_gain_to_device_cuda(self): samples=torch.from_numpy(samples).to(device=cuda_device), sample_rate=sample_rate, ) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) @@ -558,7 +560,7 @@ def test_gain_cuda(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda() processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) @@ -578,7 +580,7 @@ def test_gain_cuda_cpu(self): augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0).cuda().cpu() processed_samples = ( augment(samples=torch.from_numpy(samples).cpu(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) expected_factor = convert_decibels_to_amplitude_ratio(-6) diff --git a/tests/test_high_pass_filter.py b/tests/test_high_pass_filter.py index fc9d6fa3..d297c0c8 100644 --- a/tests/test_high_pass_filter.py +++ b/tests/test_high_pass_filter.py @@ -22,7 +22,7 @@ def test_high_pass_filter(self): augment = HighPassFilter(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) self.assertEqual(processed_samples.dtype, np.float32) @@ -41,7 +41,7 @@ def test_high_pass_filter_cuda(self): augment = HighPassFilter(p=1.0) processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) 
self.assertEqual(processed_samples.shape, samples.shape) diff --git a/tests/test_impulse_response.py b/tests/test_impulse_response.py index a6c533d4..00ea21b8 100644 --- a/tests/test_impulse_response.py +++ b/tests/test_impulse_response.py @@ -49,14 +49,13 @@ def ir_transform_no_guarantee(ir_path, sample_rate): def test_impulse_response_guaranteed_with_single_tensor_input(ir_transform, input_audio): - mixed_input = ir_transform(input_audio) + mixed_input = ir_transform(input_audio).samples assert mixed_input.shape == input_audio.shape assert not torch.equal(mixed_input, input_audio) @pytest.mark.parametrize( - "compensate_for_propagation_delay", - [False, True], + "compensate_for_propagation_delay", [False, True], ) def test_impulse_response_guaranteed_with_batched_tensor_input( ir_path, sample_rate, input_audios, compensate_for_propagation_delay @@ -66,7 +65,7 @@ def test_impulse_response_guaranteed_with_batched_tensor_input( compensate_for_propagation_delay=compensate_for_propagation_delay, p=1.0, sample_rate=sample_rate, - )(input_audios) + )(input_audios).samples assert mixed_inputs.shape == input_audios.shape assert not torch.equal(mixed_inputs, input_audios) @@ -76,7 +75,7 @@ def test_impulse_response_guaranteed_with_batched_cuda_tensor_input( input_audios, ir_transform ): input_audio_cuda = input_audios.cuda() - mixed_inputs = ir_transform(input_audio_cuda) + mixed_inputs = ir_transform(input_audio_cuda).samples assert not torch.equal(mixed_inputs, input_audio_cuda) assert mixed_inputs.shape == input_audio_cuda.shape assert mixed_inputs.dtype == input_audio_cuda.dtype @@ -86,21 +85,21 @@ def test_impulse_response_guaranteed_with_batched_cuda_tensor_input( def test_impulse_response_no_guarantee_with_single_tensor_input( input_audio, ir_transform_no_guarantee ): - mixed_input = ir_transform_no_guarantee(input_audio) + mixed_input = ir_transform_no_guarantee(input_audio).samples assert mixed_input.shape == input_audio.shape def test_impulse_response_no_guarantee_with_batched_tensor_input( input_audios, ir_transform_no_guarantee ): - mixed_inputs = ir_transform_no_guarantee(input_audios) + mixed_inputs = ir_transform_no_guarantee(input_audios).samples assert mixed_inputs.shape == input_audios.shape def test_impulse_response_guaranteed_with_zero_length_samples(ir_transform): - empty_audio = torch.empty(0) + empty_audio = torch.empty(0, 1, 16000) with pytest.warns(UserWarning, match="An empty samples tensor was passed"): - mixed_inputs = ir_transform(empty_audio) + mixed_inputs = ir_transform(empty_audio).samples assert torch.equal(mixed_inputs, empty_audio) @@ -108,7 +107,7 @@ def test_impulse_response_guaranteed_with_zero_length_samples(ir_transform): def test_impulse_response_access_file_paths(ir_path, sample_rate, input_audios): augment = ApplyImpulseResponse(ir_path, p=1.0, sample_rate=sample_rate) - mixed_inputs = augment(samples=input_audios, sample_rate=sample_rate) + mixed_inputs = augment(samples=input_audios, sample_rate=sample_rate).samples assert mixed_inputs.shape == input_audios.shape diff --git a/tests/test_low_pass_filter.py b/tests/test_low_pass_filter.py index e65708f8..9bb5559e 100644 --- a/tests/test_low_pass_filter.py +++ b/tests/test_low_pass_filter.py @@ -21,6 +21,6 @@ def test_low_pass_filter(self): augment = LowPassFilter(min_cutoff_freq=200, max_cutoff_freq=7000, p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.shape, samples.shape) 
self.assertEqual(processed_samples.dtype, np.float32) diff --git a/tests/test_peak_normalization.py b/tests/test_peak_normalization.py index 7fb3d1fc..278bb63e 100644 --- a/tests/test_peak_normalization.py +++ b/tests/test_peak_normalization.py @@ -23,7 +23,7 @@ def test_apply_to_all(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -52,7 +52,7 @@ def test_apply_to_only_too_loud_sounds(self): augment = PeakNormalization(apply_to="only_too_loud_sounds", p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -78,7 +78,7 @@ def test_digital_silence_in_batch(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -102,7 +102,7 @@ def test_only_digital_silence(self): augment = PeakNormalization(p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, @@ -127,7 +127,7 @@ def test_never_apply(self): augment = PeakNormalization(p=0.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_equal( processed_samples, @@ -157,7 +157,7 @@ def test_apply_to_all_cuda(self): augment = PeakNormalization(p=1.0) processed_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) @@ -182,7 +182,7 @@ def test_variability_within_batch(self): augment = PeakNormalization(p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() self.assertEqual(processed_samples.dtype, np.float32) num_unprocessed_examples = 0 @@ -209,11 +209,13 @@ def test_freeze_parameters(self): sample_rate = 16000 augment = PeakNormalization(p=1.0) - _ = augment(samples=torch.from_numpy(samples1), sample_rate=sample_rate).numpy() + _ = augment( + samples=torch.from_numpy(samples1), sample_rate=sample_rate + ).samples.numpy() augment.freeze_parameters() processed_samples2 = augment( samples=torch.from_numpy(samples2), sample_rate=sample_rate - ).numpy() + ).samples.numpy() augment.unfreeze_parameters() assert_almost_equal( @@ -242,7 +244,7 @@ def test_stereo_sound(self): augment = PeakNormalization(apply_to="all", p=1.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, diff --git a/tests/test_pitch_shift.py b/tests/test_pitch_shift.py index a3edcd2e..7d5e384a 100644 --- a/tests/test_pitch_shift.py +++ b/tests/test_pitch_shift.py @@ -23,21 +23,21 @@ def test_per_example_shift(self): samples = get_example() aug = PitchShift(sample_rate=16000, p=1, mode="per_example") aug.randomize_parameters(samples) - results, _ = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, samples.shape) def test_per_channel_shift(self): samples = get_example() aug = PitchShift(sample_rate=16000, p=1, mode="per_channel") aug.randomize_parameters(samples) - results, _ = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, 
samples.shape) def test_per_batch_shift(self): samples = get_example() aug = PitchShift(sample_rate=16000, p=1, mode="per_batch") aug.randomize_parameters(samples) - results, _ = aug.apply_transform(samples) + results = aug.apply_transform(samples).samples self.assertEqual(results.shape, samples.shape) def error_raised(self): diff --git a/tests/test_polarity_inversion.py b/tests/test_polarity_inversion.py index f5e7ed32..59cb5268 100644 --- a/tests/test_polarity_inversion.py +++ b/tests/test_polarity_inversion.py @@ -16,7 +16,7 @@ def test_polarity_inversion(self): augment = PolarityInversion(p=1.0) inverted_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( inverted_samples, np.array([[[-1.0, -0.5, 0.25, 0.125, 0.0]]], dtype=np.float32), @@ -30,7 +30,7 @@ def test_polarity_inversion_zero_probability(self): augment = PolarityInversion(p=0.0) processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( processed_samples, np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), @@ -45,7 +45,7 @@ def test_polarity_inversion_variability_within_batch(self): augment = PolarityInversion(p=0.5) processed_samples = augment( samples=torch.from_numpy(samples_batch), sample_rate=sample_rate - ).numpy() + ).samples.numpy() num_unprocessed_examples = 0 num_processed_examples = 0 @@ -71,7 +71,7 @@ def test_polarity_inversion_multichannel(self): augment = PolarityInversion(p=1.0) inverted_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate - ).numpy() + ).samples.numpy() assert_almost_equal( inverted_samples, np.array( @@ -89,7 +89,7 @@ def test_polarity_inversion_cuda(self): augment = PolarityInversion(p=1.0).cuda() inverted_samples = ( augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) - .cpu() + .samples.cpu() .numpy() ) assert_almost_equal( diff --git a/tests/test_shift.py b/tests/test_shift.py index 9ba1f012..148c0fce 100644 --- a/tests/test_shift.py +++ b/tests/test_shift.py @@ -27,7 +27,7 @@ def test_shift_by_1_sample_3dim(self, device_name): samples[1] += 1 augment = Shift(min_shift=1, max_shift=1, shift_unit="samples", p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_almost_equal( processed_samples.cpu(), @@ -42,7 +42,7 @@ def test_shift_by_1_sample_without_rollover(self): min_shift=1, max_shift=1, shift_unit="samples", rollover=False, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[0, 0, 1, 2], [0, 0, 1, 2]], [[0, 1, 2, 3], [0, 1, 2, 3]]], @@ -56,7 +56,7 @@ def test_negative_shift_by_2_samples(self): min_shift=-2, max_shift=-2, shift_unit="samples", rollover=True, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], @@ -70,7 +70,7 @@ def test_shift_by_fraction(self): min_shift=0.5, max_shift=0.5, shift_unit="fraction", rollover=True, p=1.0 ) - processed_samples = augment(samples=samples) + processed_samples = augment(samples=samples).samples assert_almost_equal( processed_samples, [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], @@ -83,7 +83,7 @@ def test_shift_by_seconds(self): augment = Shift( min_shift=-2, max_shift=-2, shift_unit="seconds", p=1.0, 
sample_rate=1 ) - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_almost_equal( processed_samples, @@ -105,7 +105,7 @@ def test_shift_by_seconds_specify_sample_rate_in_both_init_and_forward(self): rollover=False, ) # If sample_rate is specified in both __init__ and forward, then the latter will be used - processed_samples = augment(samples, sample_rate=forward_sample_rate) + processed_samples = augment(samples, sample_rate=forward_sample_rate).samples assert_almost_equal( processed_samples, [[[0, 0, 0, 1], [0, 0, 0, 1]], [[0, 0, 1, 2], [0, 0, 1, 2]]], @@ -127,7 +127,7 @@ def test_variability_within_batch(self): samples = torch.arange(4)[None, None].repeat(1000, 2, 1) augment = Shift(min_shift=-1, max_shift=1, shift_unit="samples", p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples applied_shift_counts = {-1: 0, 0: 0, 1: 0} for i in range(samples.shape[0]): diff --git a/tests/test_shuffle_channels.py b/tests/test_shuffle_channels.py index 052fd3c5..e264dafb 100644 --- a/tests/test_shuffle_channels.py +++ b/tests/test_shuffle_channels.py @@ -10,15 +10,12 @@ class TestShuffleChannels: def test_shuffle_mono(self): samples = torch.from_numpy( - np.array( - [[[1.0, -1.0, 1.0, -1.0, 1.0]]], - dtype=np.float32, - ) + np.array([[[1.0, -1.0, 1.0, -1.0, 1.0]]], dtype=np.float32,) ) augment = ShuffleChannels(p=1.0) with pytest.warns(UserWarning): - processed_samples = augment(samples) + processed_samples = augment(samples).samples assert_array_equal(samples.numpy(), processed_samples.numpy()) @@ -39,14 +36,13 @@ def test_variability_within_batch(self, device_name): torch.manual_seed(42) samples = np.array( - [[1.0, -1.0, 1.0, -1.0, 1.0], [0.1, -0.1, 0.1, -0.1, 0.1]], - dtype=np.float32, + [[1.0, -1.0, 1.0, -1.0, 1.0], [0.1, -0.1, 0.1, -0.1, 0.1]], dtype=np.float32, ) samples = np.stack([samples] * 1000, axis=0) samples = torch.from_numpy(samples).to(device) augment = ShuffleChannels(p=1.0) - processed_samples = augment(samples) + processed_samples = augment(samples).samples orders = {"original": 0, "swapped": 0} for i in range(processed_samples.shape[0]): diff --git a/tests/test_someof.py b/tests/test_someof.py index be6ac0d7..a36c4007 100644 --- a/tests/test_someof.py +++ b/tests/test_someof.py @@ -23,21 +23,27 @@ def test_someof(self): augment = SomeOf(2, self.transforms) self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied def test_someof_with_p_zero(self): augment = SomeOf(2, self.transforms, p=0.0) self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertEqual(len(augment.transform_indexes), 0) # 0 transforms applied def test_someof_tuple(self): augment = SomeOf((1, None), self.transforms) self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet - processed_samples = augment(samples=self.audio, sample_rate=self.sample_rate) + processed_samples = augment( + samples=self.audio, sample_rate=self.sample_rate + ).samples self.assertTrue( len(augment.transform_indexes) > 0 ) # at least one transform applied @@ -51,7 +57,7 
@@ def test_someof_freeze_and_unfreeze_parameters(self): self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet processed_samples1 = augment( samples=samples, sample_rate=self.sample_rate - ).numpy() + ).samples.numpy() transform_indexes1 = augment.transform_indexes self.assertEqual(len(augment.transform_indexes), 2) @@ -59,7 +65,7 @@ def test_someof_freeze_and_unfreeze_parameters(self): processed_samples2 = augment( samples=samples, sample_rate=self.sample_rate - ).numpy() + ).samples.numpy() transform_indexes2 = augment.transform_indexes assert_array_equal(processed_samples1, processed_samples2) assert_array_equal(transform_indexes1, transform_indexes2) diff --git a/tests/test_time_inversion.py b/tests/test_time_inversion.py index b65865b4..ca1bc505 100644 --- a/tests/test_time_inversion.py +++ b/tests/test_time_inversion.py @@ -12,7 +12,7 @@ def setUp(self): def test_single_channel(self): samples = self.samples.unsqueeze(0).unsqueeze(0) # (B, C, T): (1, 1, 100) - processed_samples = self.augment(samples=samples, sample_rate=16000) + processed_samples = self.augment(samples=samples, sample_rate=16000).samples self.assertEqual(processed_samples.shape, samples.shape) self.assertTrue( @@ -25,7 +25,7 @@ def test_multi_channel(self): samples = torch.stack([self.samples, self.samples], dim=0).unsqueeze( 0 ) # (B, C, T): (1, 2, 100) - processed_samples = self.augment(samples=samples, sample_rate=16000) + processed_samples = self.augment(samples=samples, sample_rate=16000).samples self.assertEqual(processed_samples.shape, samples.shape) self.assertTrue( diff --git a/torch_audiomentations/augmentations/background_noise.py b/torch_audiomentations/augmentations/background_noise.py index 47262752..ef5d9ad5 100644 --- a/torch_audiomentations/augmentations/background_noise.py +++ b/torch_audiomentations/augmentations/background_noise.py @@ -1,13 +1,15 @@ import random from pathlib import Path -from typing import Union, List +from typing import Union, List, Optional import torch +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform, EmptyPathException from ..utils.dsp import calculate_rms from ..utils.file import find_audio_files from ..utils.io import Audio +from ..utils.object_dict import ObjectDict class AddBackgroundNoise(BaseWaveformTransform): @@ -15,11 +17,16 @@ class AddBackgroundNoise(BaseWaveformTransform): Add background noise to the input audio. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + # Note: This transform has only partial support for multichannel audio. Noises that are not # mono get mixed down to mono before they are added to all channels in the input. 
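+    # SNR sketch: the RMS-normalized noise is scaled in apply_transform so that
+    #     rms(scaled_noise) = rms(signal) / 10 ** (snr_in_db / 20)
+    # which makes 20 * log10(rms(signal) / rms(scaled_noise)) equal snr_in_db.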
supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, background_paths: Union[List[Path], List[str], Path, str], @@ -102,17 +109,17 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # (batch_size, num_samples) RMS-normalized background noise audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) @@ -126,19 +133,15 @@ def randomize_parameters( size=(batch_size,), fill_value=self.min_snr_in_db, dtype=torch.float32, - device=selected_samples.device, + device=samples.device, ) else: snr_distribution = torch.distributions.Uniform( low=torch.tensor( - self.min_snr_in_db, - dtype=torch.float32, - device=selected_samples.device, + self.min_snr_in_db, dtype=torch.float32, device=samples.device, ), high=torch.tensor( - self.max_snr_in_db, - dtype=torch.float32, - device=selected_samples.device, + self.max_snr_in_db, dtype=torch.float32, device=samples.device, ), validate_args=True, ) @@ -148,24 +151,26 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - batch_size, num_channels, num_samples = selected_samples.shape + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + batch_size, num_channels, num_samples = samples.shape # (batch_size, num_samples) - background = self.transform_parameters["background"].to(selected_samples.device) + background = self.transform_parameters["background"].to(samples.device) # (batch_size, num_channels) - background_rms = calculate_rms(selected_samples) / ( + background_rms = calculate_rms(samples) / ( 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return ( - selected_samples + return ObjectDict( + samples=samples + background_rms.unsqueeze(-1) * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), - targets, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, ) diff --git a/torch_audiomentations/augmentations/band_pass_filter.py b/torch_audiomentations/augmentations/band_pass_filter.py index 37e7e512..8c58623f 100644 --- a/torch_audiomentations/augmentations/band_pass_filter.py +++ b/torch_audiomentations/augmentations/band_pass_filter.py @@ -1,8 +1,10 @@ import julius import torch - +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies +from ..utils.object_dict import ObjectDict class BandPassFilter(BaseWaveformTransform): @@ -10,9 +12,14 @@ class BandPassFilter(BaseWaveformTransform): Apply band-pass filtering to the input audio. 
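+
+    Usage sketch (shapes and rate are illustrative):
+
+        transform = BandPassFilter(p=1.0)
+        output = transform(samples, sample_rate=16000)  # returns an ObjectDict
+        filtered = output.samples  # same (batch, channels, time) shape as input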
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_center_frequency=200, @@ -73,28 +80,25 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + + batch_size, _, num_samples = samples.shape # Sample frequencies uniformly in mel space, then convert back to frequency def get_dist(min_freq, max_freq): dist = torch.distributions.Uniform( low=convert_frequencies_to_mels( - torch.tensor( - min_freq, dtype=torch.float32, device=selected_samples.device, - ) + torch.tensor(min_freq, dtype=torch.float32, device=samples.device,) ), high=convert_frequencies_to_mels( - torch.tensor( - max_freq, dtype=torch.float32, device=selected_samples.device, - ) + torch.tensor(max_freq, dtype=torch.float32, device=samples.device,) ), validate_args=True, ) @@ -114,15 +118,12 @@ def get_dist(min_freq, max_freq): def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - batch_size, num_channels, num_samples = selected_samples.shape - - if sample_rate is None: - sample_rate = self.sample_rate + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + batch_size, num_channels, num_samples = samples.shape low_cutoffs_as_fraction_of_sample_rate = ( self.transform_parameters["center_freq"] @@ -136,10 +137,15 @@ def apply_transform( ) # TODO: Instead of using a for loop, perform batched compute to speed things up for i in range(batch_size): - selected_samples[i] = julius.bandpass_filter( - selected_samples[i], + samples[i] = julius.bandpass_filter( + samples[i], cutoff_low=low_cutoffs_as_fraction_of_sample_rate[i].item(), cutoff_high=high_cutoffs_as_fraction_of_sample_rate[i].item(), ) - return selected_samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/band_stop_filter.py b/torch_audiomentations/augmentations/band_stop_filter.py index 991b89e7..272d9f1e 100644 --- a/torch_audiomentations/augmentations/band_stop_filter.py +++ b/torch_audiomentations/augmentations/band_stop_filter.py @@ -1,6 +1,8 @@ -import torch +from torch import Tensor +from typing import Optional from ..augmentations.band_pass_filter import BandPassFilter +from ..utils.object_dict import ObjectDict class BandStopFilter(BandPassFilter): @@ -9,9 +11,6 @@ class BandStopFilter(BandPassFilter): band reject filter and frequency mask. 
""" - supports_multichannel = True - requires_sample_rate = True - def __init__( self, min_center_frequency=200, @@ -52,15 +51,18 @@ def __init__( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - band_pass_filtered_samples, band_pass_filtered_targets = super().apply_transform( - selected_samples.clone(), + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + perturbed = super().apply_transform( + samples.clone(), sample_rate, targets=targets.clone() if targets is not None else None, target_rate=target_rate, ) - return selected_samples - band_pass_filtered_samples, band_pass_filtered_targets + + perturbed.samples = samples - perturbed.samples + return perturbed diff --git a/torch_audiomentations/augmentations/colored_noise.py b/torch_audiomentations/augmentations/colored_noise.py index 6dd30bf6..88066705 100644 --- a/torch_audiomentations/augmentations/colored_noise.py +++ b/torch_audiomentations/augmentations/colored_noise.py @@ -1,10 +1,13 @@ import torch +from torch import Tensor +from typing import Optional from math import ceil from torch_audiomentations.utils.fft import rfft, irfft from ..core.transforms_interface import BaseWaveformTransform from ..utils.dsp import calculate_rms from ..utils.io import Audio +from ..utils.object_dict import ObjectDict def _gen_noise(f_decay, num_samples, sample_rate, device): @@ -33,9 +36,14 @@ class AddColoredNoise(BaseWaveformTransform): Add colored noises to the input audio. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_snr_in_db: float = 3.0, @@ -85,15 +93,15 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ :params selected_samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # (batch_size, ) SNRs for param, mini, maxi in [ @@ -101,27 +109,21 @@ def randomize_parameters( ("f_decay", self.min_f_decay, self.max_f_decay), ]: dist = torch.distributions.Uniform( - low=torch.tensor( - mini, dtype=torch.float32, device=selected_samples.device - ), - high=torch.tensor( - maxi, dtype=torch.float32, device=selected_samples.device - ), + low=torch.tensor(mini, dtype=torch.float32, device=samples.device), + high=torch.tensor(maxi, dtype=torch.float32, device=samples.device), validate_args=True, ) self.transform_parameters[param] = dist.sample(sample_shape=(batch_size,)) def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - batch_size, num_channels, num_samples = selected_samples.shape + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - if sample_rate is None: - sample_rate = self.sample_rate + batch_size, num_channels, num_samples = samples.shape # (batch_size, num_samples) noise = torch.stack( @@ -130,20 +132,22 @@ def apply_transform( 
self.transform_parameters["f_decay"][i], num_samples, sample_rate, - selected_samples.device, + samples.device, ) for i in range(batch_size) ] ) # (batch_size, num_channels) - noise_rms = calculate_rms(selected_samples) / ( + noise_rms = calculate_rms(samples) / ( 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) ) - return ( - selected_samples + return ObjectDict( + samples=samples + noise_rms.unsqueeze(-1) * noise.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), - targets, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, ) diff --git a/torch_audiomentations/augmentations/gain.py b/torch_audiomentations/augmentations/gain.py index ad49709a..f647987d 100644 --- a/torch_audiomentations/augmentations/gain.py +++ b/torch_audiomentations/augmentations/gain.py @@ -1,8 +1,10 @@ import torch -import typing +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform from ..utils.dsp import convert_decibels_to_amplitude_ratio +from ..utils.object_dict import ObjectDict class Gain(BaseWaveformTransform): @@ -15,17 +17,23 @@ class Gain(BaseWaveformTransform): See also https://en.wikipedia.org/wiki/Clipping_(audio)#Digital_clipping """ + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, min_gain_in_db: float = -18.0, max_gain_in_db: float = 6.0, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, - target_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): super().__init__( mode=mode, @@ -41,21 +49,21 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): distribution = torch.distributions.Uniform( low=torch.tensor( - self.min_gain_in_db, dtype=torch.float32, device=selected_samples.device + self.min_gain_in_db, dtype=torch.float32, device=samples.device ), high=torch.tensor( - self.max_gain_in_db, dtype=torch.float32, device=selected_samples.device + self.max_gain_in_db, dtype=torch.float32, device=samples.device ), validate_args=True, ) - selected_batch_size = selected_samples.size(0) + selected_batch_size = samples.size(0) self.transform_parameters["gain_factors"] = ( convert_decibels_to_amplitude_ratio( distribution.sample(sample_shape=(selected_batch_size,)) @@ -66,9 +74,16 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - return selected_samples * self.transform_parameters["gain_factors"], targets + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + return ObjectDict( + samples=samples * self.transform_parameters["gain_factors"], + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + diff --git a/torch_audiomentations/augmentations/high_pass_filter.py b/torch_audiomentations/augmentations/high_pass_filter.py index 1b0a7c5a..9975318f 100644 --- 
a/torch_audiomentations/augmentations/high_pass_filter.py +++ b/torch_audiomentations/augmentations/high_pass_filter.py @@ -1,6 +1,8 @@ -import torch +from torch import Tensor +from typing import Optional from ..augmentations.low_pass_filter import LowPassFilter +from ..utils.object_dict import ObjectDict class HighPassFilter(LowPassFilter): @@ -8,9 +10,6 @@ class HighPassFilter(LowPassFilter): Apply high-pass filtering to the input audio. """ - supports_multichannel = True - requires_sample_rate = True - def __init__( self, min_cutoff_freq=20, @@ -43,15 +42,18 @@ def __init__( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - low_pass_filtered_samples, low_pass_filtered_targets = super().apply_transform( - selected_samples.clone(), - sample_rate, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + perturbed = super().apply_transform( + samples=samples.clone(), + sample_rate=sample_rate, targets=targets.clone() if targets is not None else None, target_rate=target_rate, ) - return selected_samples - low_pass_filtered_samples, low_pass_filtered_targets + + perturbed.samples = samples - perturbed.samples + return perturbed diff --git a/torch_audiomentations/augmentations/impulse_response.py b/torch_audiomentations/augmentations/impulse_response.py index ee12d1b5..db60f38d 100644 --- a/torch_audiomentations/augmentations/impulse_response.py +++ b/torch_audiomentations/augmentations/impulse_response.py @@ -1,6 +1,7 @@ import random from pathlib import Path -from typing import Union, List +from typing import Union, List, Optional +from torch import Tensor import torch from torch.nn.utils.rnn import pad_sequence @@ -9,6 +10,7 @@ from ..utils.convolution import convolve from ..utils.file import find_audio_files from ..utils.io import Audio +from ..utils.object_dict import ObjectDict class ApplyImpulseResponse(BaseWaveformTransform): @@ -16,11 +18,16 @@ class ApplyImpulseResponse(BaseWaveformTransform): Convolve the given audio with impulse responses. """ + supported_modes = {"per_batch", "per_example", "per_channel"} + # Note: This transform has only partial support for multichannel audio. IRs that are not # mono get mixed down to mono before they are convolved with all channels in the input. 
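+    # Length sketch: in "full" mode, convolving num_samples with an IR of
+    # length ir_length yields num_samples + ir_length - 1 samples, so
+    # apply_transform crops back to num_samples (optionally after
+    # compensating for the IR's propagation delay).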
supports_multichannel = True requires_sample_rate = True + supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) + requires_target = False + def __init__( self, ir_paths: Union[List[Path], List[str], Path, str], @@ -72,13 +79,13 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): - batch_size, _, _ = selected_samples.shape + batch_size, _, _ = samples.shape audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) @@ -94,19 +101,19 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape # (batch_size, 1, max_ir_length) - ir = self.transform_parameters["ir"].to(selected_samples.device) + ir = self.transform_parameters["ir"].to(samples.device) convolved_samples = convolve( - selected_samples, ir.expand(-1, num_channels, -1), mode=self.convolve_mode + samples, ir.expand(-1, num_channels, -1), mode=self.convolve_mode ) if self.compensate_for_propagation_delay: @@ -123,9 +130,18 @@ def apply_transform( dim=0, ) - # FIXME should we compensate targets as well? - return convolved_samples, targets + return ObjectDict( + samples=convolved_samples, + sample_rate=sample_rate, + targets=targets, # FIXME compensate targets as well? + target_rate=target_rate, + ) else: - # FIXME should we strip targets as well? - return convolved_samples[..., :num_samples], targets + return ObjectDict( + samples=convolved_samples[..., :num_samples], + sample_rate=sample_rate, + targets=targets, # FIXME crop targets as well? + target_rate=target_rate, + ) + diff --git a/torch_audiomentations/augmentations/low_pass_filter.py b/torch_audiomentations/augmentations/low_pass_filter.py index ace821e6..4382cc49 100644 --- a/torch_audiomentations/augmentations/low_pass_filter.py +++ b/torch_audiomentations/augmentations/low_pass_filter.py @@ -1,8 +1,12 @@ import julius import torch +from torch import Tensor +from typing import Optional + from ..core.transforms_interface import BaseWaveformTransform from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies +from ..utils.object_dict import ObjectDict class LowPassFilter(BaseWaveformTransform): @@ -10,9 +14,14 @@ class LowPassFilter(BaseWaveformTransform): Apply low-pass filtering to the input audio. 
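+
+    The cutoff frequency is drawn uniformly on the mel scale between
+    min_cutoff_freq and max_cutoff_freq and then converted back to hertz,
+    which spreads the draws more evenly across a perceptual frequency axis.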
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_cutoff_freq=150, @@ -46,30 +55,26 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :params selected_samples: (batch_size, num_channels, num_samples) + :params samples: (batch_size, num_channels, num_samples) """ - batch_size, _, num_samples = selected_samples.shape + batch_size, _, num_samples = samples.shape # Sample frequencies uniformly in mel space, then convert back to frequency dist = torch.distributions.Uniform( low=convert_frequencies_to_mels( torch.tensor( - self.min_cutoff_freq, - dtype=torch.float32, - device=selected_samples.device, + self.min_cutoff_freq, dtype=torch.float32, device=samples.device, ) ), high=convert_frequencies_to_mels( torch.tensor( - self.max_cutoff_freq, - dtype=torch.float32, - device=selected_samples.device, + self.max_cutoff_freq, dtype=torch.float32, device=samples.device, ) ), validate_args=True, @@ -80,23 +85,26 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - batch_size, num_channels, num_samples = selected_samples.shape + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - if sample_rate is None: - sample_rate = self.sample_rate + batch_size, num_channels, num_samples = samples.shape cutoffs_as_fraction_of_sample_rate = ( self.transform_parameters["cutoff_freq"] / sample_rate ) # TODO: Instead of using a for loop, perform batched compute to speed things up for i in range(batch_size): - selected_samples[i] = julius.lowpass_filter( - selected_samples[i], cutoffs_as_fraction_of_sample_rate[i].item() + samples[i] = julius.lowpass_filter( + samples[i], cutoffs_as_fraction_of_sample_rate[i].item() ) - return selected_samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/mix.py b/torch_audiomentations/augmentations/mix.py index 7bdf8681..f73678de 100644 --- a/torch_audiomentations/augmentations/mix.py +++ b/torch_audiomentations/augmentations/mix.py @@ -1,9 +1,11 @@ from typing import Optional import torch +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform from ..utils.dsp import calculate_rms from ..utils.io import Audio +from ..utils.object_dict import ObjectDict class Mix(BaseWaveformTransform): @@ -19,11 +21,13 @@ class Mix(BaseWaveformTransform): """ - supports_multichannel = True supported_modes = {"per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = False - requires_targets = False - requires_target_rate = False + + supports_target = True + requires_target = False def __init__( self, @@ -62,19 +66,19 @@ def __init__( def randomize_parameters( self, - selected_samples, + samples: Tensor = None, sample_rate: Optional[int] = None, - targets=None, + targets: Optional[Tensor] = None, target_rate: Optional[int] = None, ): - batch_size, num_channels, num_samples = 
selected_samples.shape + batch_size, num_channels, num_samples = samples.shape snr_distribution = torch.distributions.Uniform( low=torch.tensor( - self.min_snr_in_db, dtype=torch.float32, device=selected_samples.device, + self.min_snr_in_db, dtype=torch.float32, device=samples.device, ), high=torch.tensor( - self.max_snr_in_db, dtype=torch.float32, device=selected_samples.device, + self.max_snr_in_db, dtype=torch.float32, device=samples.device, ), validate_args=True, ) @@ -86,31 +90,35 @@ def randomize_parameters( # randomize index of second sample self.transform_parameters["sample_idx"] = torch.randint( - 0, batch_size, (batch_size,), device=selected_samples.device, + 0, batch_size, (batch_size,), device=samples.device, ) def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: snr = self.transform_parameters["snr_in_db"] idx = self.transform_parameters["sample_idx"] - background_samples = Audio.rms_normalize(selected_samples[idx]) - background_rms = calculate_rms(selected_samples) / ( - 10 ** (snr.unsqueeze(dim=-1) / 20) - ) + background_samples = Audio.rms_normalize(samples[idx]) + background_rms = calculate_rms(samples) / (10 ** (snr.unsqueeze(dim=-1) / 20)) + + mixed_samples = samples + background_rms.unsqueeze(-1) * background_samples - perturbed_samples = ( - selected_samples + background_rms.unsqueeze(-1) * background_samples - ) if targets is None: - return perturbed_samples + mixed_targets = None + + else: + background_targets = targets[idx] + mixed_targets = self._mix_target(targets, background_targets, snr) - background_targets = targets[idx] - perturbed_targets = self._mix_target(targets, background_targets, snr) - return perturbed_samples, perturbed_targets + return ObjectDict( + samples=mixed_samples, + sample_rate=sample_rate, + targets=mixed_targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/peak_normalization.py b/torch_audiomentations/augmentations/peak_normalization.py index 0dadfcb8..963a1724 100644 --- a/torch_audiomentations/augmentations/peak_normalization.py +++ b/torch_audiomentations/augmentations/peak_normalization.py @@ -1,7 +1,10 @@ import torch import typing +from typing import Optional +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class PeakNormalization(BaseWaveformTransform): @@ -16,8 +19,14 @@ class PeakNormalization(BaseWaveformTransform): untouched. 
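+
+    Example (illustrative): with apply_to="all", an example peaking at 0.5 is
+    divided by 0.5 so that its new peak is 1.0; with
+    apply_to="only_too_loud_sounds", roughly only examples peaking above 1.0
+    get rescaled, and quieter ones pass through unchanged.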
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, apply_to="all", @@ -39,13 +48,13 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): # Compute the most extreme value of each multichannel audio snippet in the batch - most_extreme_values, _ = torch.max(torch.abs(selected_samples), dim=-1) + most_extreme_values, _ = torch.max(torch.abs(samples), dim=-1) most_extreme_values, _ = torch.max(most_extreme_values, dim=-1) if self.apply_to == "all": @@ -68,13 +77,20 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + if "divisors" in self.transform_parameters: - selected_samples[ - self.transform_parameters["selector"] - ] /= self.transform_parameters["divisors"] - return selected_samples, targets + samples[self.transform_parameters["selector"]] /= self.transform_parameters[ + "divisors" + ] + + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/pitch_shift.py b/torch_audiomentations/augmentations/pitch_shift.py index bde673bc..dfda3da6 100644 --- a/torch_audiomentations/augmentations/pitch_shift.py +++ b/torch_audiomentations/augmentations/pitch_shift.py @@ -1,9 +1,11 @@ from random import choices -import torch +from torch import Tensor +from typing import Optional from torch_pitch_shift import pitch_shift, get_fast_shifts, semitones_to_ratio from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class PitchShift(BaseWaveformTransform): @@ -11,9 +13,14 @@ class PitchShift(BaseWaveformTransform): Pitch-shift sounds up or down without changing the tempo. 
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = True + supports_target = True + requires_target = False + def __init__( self, min_transpose_semitones: float = -4.0, @@ -60,16 +67,16 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): """ - :param selected_samples: (batch_size, num_channels, num_samples) + :param samples: (batch_size, num_channels, num_samples) :param sample_rate: """ - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape if self._mode == "per_example": self.transform_parameters["transpositions"] = choices( @@ -89,16 +96,16 @@ def randomize_parameters( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: """ - :param selected_samples: (batch_size, num_channels, num_samples) + :param samples: (batch_size, num_channels, num_samples) :param sample_rate: """ - batch_size, num_channels, num_samples = selected_samples.shape + batch_size, num_channels, num_samples = samples.shape if sample_rate is not None and sample_rate != self._sample_rate: raise ValueError( @@ -109,27 +116,29 @@ def apply_transform( if self._mode == "per_example": for i in range(batch_size): - selected_samples[i, ...] = pitch_shift( - selected_samples[i][None], + samples[i, ...] = pitch_shift( + samples[i][None], self.transform_parameters["transpositions"][i], sample_rate, )[0] + elif self._mode == "per_channel": for i in range(batch_size): for j in range(num_channels): - selected_samples[i, j, ...] = pitch_shift( - selected_samples[i][j][None][None], + samples[i, j, ...] = pitch_shift( + samples[i][j][None][None], self.transform_parameters["transpositions"][i][j], sample_rate, )[0][0] + elif self._mode == "per_batch": - return ( - pitch_shift( - selected_samples, - self.transform_parameters["transpositions"][0], - sample_rate, - ), - targets, + samples = pitch_shift( + samples, self.transform_parameters["transpositions"][0], sample_rate ) - return selected_samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/polarity_inversion.py b/torch_audiomentations/augmentations/polarity_inversion.py index 4894afc3..53461cb5 100644 --- a/torch_audiomentations/augmentations/polarity_inversion.py +++ b/torch_audiomentations/augmentations/polarity_inversion.py @@ -1,7 +1,8 @@ -import typing -import torch +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class PolarityInversion(BaseWaveformTransform): @@ -15,16 +16,21 @@ class PolarityInversion(BaseWaveformTransform): training phase-aware machine learning models. 
""" + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, - target_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): super().__init__( mode=mode, @@ -36,9 +42,16 @@ def __init__( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): - return -selected_samples, targets + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + return ObjectDict( + samples=-samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + diff --git a/torch_audiomentations/augmentations/shift.py b/torch_audiomentations/augmentations/shift.py index f32c562e..fd471717 100644 --- a/torch_audiomentations/augmentations/shift.py +++ b/torch_audiomentations/augmentations/shift.py @@ -1,7 +1,9 @@ import torch -import typing +from typing import Optional +from torch import Tensor from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict def shift_gpu(tensor: torch.Tensor, r: torch.Tensor, rollover: bool = False): @@ -51,6 +53,14 @@ class Shift(BaseWaveformTransform): Shift the audio forwards or backwards, with or without rollover """ + supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True + requires_sample_rate = True + + supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) + requires_target = False + def __init__( self, min_shift: float = -0.5, @@ -59,9 +69,9 @@ def __init__( rollover: bool = True, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, - target_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): """ @@ -99,18 +109,18 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): if self.shift_unit == "samples": min_shift_in_samples = self.min_shift max_shift_in_samples = self.max_shift elif self.shift_unit == "fraction": - min_shift_in_samples = int(round(self.min_shift * selected_samples.shape[-1])) - max_shift_in_samples = int(round(self.max_shift * selected_samples.shape[-1])) + min_shift_in_samples = int(round(self.min_shift * samples.shape[-1])) + max_shift_in_samples = int(round(self.max_shift * samples.shape[-1])) elif self.shift_unit == "seconds": min_shift_in_samples = int(round(self.min_shift * sample_rate)) @@ -129,13 +139,13 @@ def randomize_parameters( <= max_shift_in_samples <= torch.iinfo(torch.int32).max ) - selected_batch_size = selected_samples.size(0) + selected_batch_size = samples.size(0) if min_shift_in_samples == max_shift_in_samples: self.transform_parameters["num_samples_to_shift"] = torch.full( size=(selected_batch_size,), fill_value=min_shift_in_samples, 
dtype=torch.int32, - device=selected_samples.device, + device=samples.device, ) else: @@ -144,28 +154,27 @@ def randomize_parameters( high=max_shift_in_samples + 1, size=(selected_batch_size,), dtype=torch.int32, - device=selected_samples.device, + device=samples.device, ) def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: num_samples_to_shift = self.transform_parameters["num_samples_to_shift"] # Select fastest implementation based on device - shift = shift_gpu if selected_samples.device.type == "cuda" else shift_cpu - shifted_samples = shift(selected_samples, num_samples_to_shift, self.rollover) + shift = shift_gpu if samples.device.type == "cuda" else shift_cpu + shifted_samples = shift(samples, num_samples_to_shift, self.rollover) - if targets is None: + if targets is None or target_rate == 0: shifted_targets = targets + else: - # FIXME corner case where target_rate is missing - # FIXME corner case where target is not correlated with the input length num_frames_to_shift = int( round(target_rate * num_samples_to_shift / sample_rate) ) @@ -173,12 +182,13 @@ def apply_transform( targets.transpose(-2, -1), num_frames_to_shift, self.rollover ).transpose(-2, -1) - return shifted_samples, shifted_targets + return ObjectDict( + samples=shifted_samples, + sample_rate=sample_rate, + targets=shifted_targets, + target_rate=target_rate, + ) def is_sample_rate_required(self) -> bool: # Sample rate is required only if shift_unit is "seconds" return self.shift_unit == "seconds" - - def is_target_rate_required(self) -> bool: - # FIXME should be True only when targets is passed to apply_transform - return self.requires_target_rate diff --git a/torch_audiomentations/augmentations/shuffle_channels.py b/torch_audiomentations/augmentations/shuffle_channels.py index b016d2a9..ad1c15eb 100644 --- a/torch_audiomentations/augmentations/shuffle_channels.py +++ b/torch_audiomentations/augmentations/shuffle_channels.py @@ -1,9 +1,12 @@ -import typing +from typing import Optional import warnings import torch +from torch import Tensor + from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class ShuffleChannels(BaseWaveformTransform): @@ -22,9 +25,9 @@ def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, - target_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): super().__init__( mode=mode, @@ -36,41 +39,49 @@ def __init__( def randomize_parameters( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): - batch_size = selected_samples.shape[0] - num_channels = selected_samples.shape[1] + batch_size = samples.shape[0] + num_channels = samples.shape[1] assert num_channels <= 255 permutations = torch.zeros( - (batch_size, num_channels), dtype=torch.int64, device=selected_samples.device + (batch_size, num_channels), dtype=torch.int64, device=samples.device ) for i in range(batch_size): - permutations[i] = 
torch.randperm(num_channels, device=selected_samples.device) + permutations[i] = torch.randperm(num_channels, device=samples.device) self.transform_parameters["permutations"] = permutations def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: - if selected_samples.shape[1] == 1: + if samples.shape[1] == 1: warnings.warn( "Mono audio was passed to ShuffleChannels - there are no channels to shuffle." " The input will be returned unchanged." ) - return selected_samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) - for i in range(selected_samples.size(0)): - selected_samples[i] = selected_samples[ - i, self.transform_parameters["permutations"][i] - ] + for i in range(samples.size(0)): + samples[i] = samples[i, self.transform_parameters["permutations"][i]] if targets is not None: targets[i] = targets[i, self.transform_parameters["permutations"][i]] - return selected_samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/augmentations/time_inversion.py b/torch_audiomentations/augmentations/time_inversion.py index c4c16542..c1ef6676 100644 --- a/torch_audiomentations/augmentations/time_inversion.py +++ b/torch_audiomentations/augmentations/time_inversion.py @@ -1,6 +1,9 @@ import torch +from torch import Tensor +from typing import Optional from ..core.transforms_interface import BaseWaveformTransform +from ..utils.object_dict import ObjectDict class TimeInversion(BaseWaveformTransform): @@ -12,9 +15,14 @@ class TimeInversion(BaseWaveformTransform): https://arxiv.org/pdf/2106.13043.pdf """ + supported_modes = {"per_batch", "per_example", "per_channel"} + supports_multichannel = True requires_sample_rate = False + supports_target = True + requires_target = False + def __init__( self, mode: str = "per_example", @@ -39,11 +47,11 @@ def __init__( def apply_transform( self, - selected_samples: torch.Tensor, - sample_rate: int = None, - targets: torch.Tensor = None, - target_rate: int = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: # torch.flip() is supposed to be slower than np.flip() # An alternative is to use advanced indexing: https://github.com/pytorch/pytorch/issues/16424 @@ -51,10 +59,15 @@ def apply_transform( # transformed_samples = selected_samples[..., reverse_index] # return transformed_samples - flipped_samples = torch.flip(selected_samples, dims=(-1,)) + flipped_samples = torch.flip(samples, dims=(-1,)) if targets is None: flipped_targets = targets else: flipped_targets = torch.flip(targets, dims=(-2,)) - return flipped_samples, flipped_targets + return ObjectDict( + samples=flipped_samples, + sample_rate=sample_rate, + targets=flipped_targets, + target_rate=target_rate, + ) diff --git a/torch_audiomentations/core/composition.py b/torch_audiomentations/core/composition.py index 9aecb90c..ab9443ff 100644 --- a/torch_audiomentations/core/composition.py +++ b/torch_audiomentations/core/composition.py @@ -1,10 +1,11 @@ import random -from typing import List +from typing import List, Union, Optional, Tuple -import torch -import typing 
+from torch import Tensor +import torch.nn from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations.utils.object_dict import ObjectDict class BaseCompose(torch.nn.Module): @@ -12,7 +13,9 @@ class BaseCompose(torch.nn.Module): def __init__( self, - transforms: List[torch.nn.Module], + transforms: List[ + torch.nn.Module + ], # FIXME: do we really want to support regular nn.Module? shuffle: bool = False, p: float = 1.0, p_mode="per_batch", @@ -65,32 +68,34 @@ def supported_modes(self) -> set: class Compose(BaseCompose): def forward( self, - samples, - sample_rate: typing.Optional[int] = None, - targets=None, - target_rate: typing.Optional[int] = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + inputs = ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if random.random() < self.p: transform_indexes = list(range(len(self.transforms))) if self.shuffle: random.shuffle(transform_indexes) for i in transform_indexes: tfm = self.transforms[i] - if isinstance(tfm, BaseWaveformTransform): - if targets is None: - samples = self.transforms[i](samples, sample_rate) - else: - samples, targets = self.transforms[i]( - samples, sample_rate, targets=targets, target_rate=target_rate - ) + if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): + inputs = self.transforms[i](**inputs) + else: - # FIXME: add support for targets? - samples = self.transforms[i](samples) + assert isinstance(tfm, torch.nn.Module) + # FIXME: do we really want to support regular nn.Module? + inputs.samples = self.transforms[i](inputs.samples) - if targets is None: - return samples - else: - return samples, targets + return inputs class SomeOf(BaseCompose): @@ -112,7 +117,7 @@ class SomeOf(BaseCompose): def __init__( self, - num_transforms: typing.Union[int, typing.Tuple[int, int]], + num_transforms: Union[int, Tuple[int, int]], transforms: List[torch.nn.Module], p: float = 1.0, p_mode="per_batch", @@ -149,11 +154,19 @@ def randomize_parameters(self): def forward( self, - samples, - sample_rate: typing.Optional[int] = None, - targets=None, - target_rate: typing.Optional[int] = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: + + inputs = ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if random.random() < self.p: if not self.are_parameters_frozen: @@ -161,21 +174,15 @@ def forward( for i in self.transform_indexes: tfm = self.transforms[i] - if isinstance(tfm, BaseWaveformTransform): - if targets is None: - samples = self.transforms[i](samples, sample_rate) - else: - samples, targets = self.transforms[i]( - samples, sample_rate, targets=targets, target_rate=target_rate - ) + if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): + inputs = self.transforms[i](**inputs) + else: - # FIXME: add support for targets? - samples = self.transforms[i](samples) + assert isinstance(tfm, torch.nn.Module) + # FIXME: do we really want to support regular nn.Module? 
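+                    # Plain nn.Module transforms know nothing about ObjectDict:
+                    # they receive and return bare sample tensors, and any targets
+                    # in `inputs` are passed through untouched.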
+ inputs.samples = self.transforms[i](inputs.samples) - if targets is None: - return samples - else: - return samples, targets + return inputs class OneOf(SomeOf): diff --git a/torch_audiomentations/core/transforms_interface.py b/torch_audiomentations/core/transforms_interface.py index c3d4446e..bab32c7a 100644 --- a/torch_audiomentations/core/transforms_interface.py +++ b/torch_audiomentations/core/transforms_interface.py @@ -1,10 +1,12 @@ import warnings import torch -import typing +from torch import Tensor +from typing import Optional from torch.distributions import Bernoulli from torch_audiomentations.utils.multichannel import is_multichannel +from torch_audiomentations.utils.object_dict import ObjectDict class MultichannelAudioNotSupportedException(Exception): @@ -21,19 +23,21 @@ class ModeNotSupportedException(Exception): class BaseWaveformTransform(torch.nn.Module): - supports_multichannel = True supported_modes = {"per_batch", "per_example", "per_channel"} + + supports_multichannel = True requires_sample_rate = True - requires_targets = False - requires_target_rate = False + + supports_target = True + requires_target = False def __init__( self, mode: str = "per_example", p: float = 0.5, - p_mode: typing.Optional[str] = None, - sample_rate: typing.Optional[int] = None, - target_rate: typing.Optional[int] = None, + p_mode: Optional[str] = None, + sample_rate: Optional[int] = None, + target_rate: Optional[int] = None, ): """ @@ -96,36 +100,42 @@ def p(self, p): def forward( self, - samples, - sample_rate: typing.Optional[int] = None, - targets=None, - target_rate: typing.Optional[int] = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + # TODO: add support for additional **kwargs (batch_size, ...)-shaped tensors + # TODO: but do that only when we actually have a use case for it... + ) -> ObjectDict: if not self.training: - if targets is None: - return samples - else: - return samples, targets - - if len(samples) == 0: - warnings.warn( - "An empty samples tensor was passed to {}".format(self.__class__.__name__) + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, ) - if targets is None: - return samples - else: - return samples, targets - if len(samples.shape) != 3: + if not isinstance(samples, Tensor) or len(samples.shape) != 3: raise RuntimeError( - "torch-audiomentations expects input tensors to be three-dimensional, with" + "torch-audiomentations expects three-dimensional input tensors, with" " dimension ordering like [batch_size, num_channels, num_samples]. If your" " audio is mono, you can use a shape like [batch_size, 1, num_samples]." 
) batch_size, num_channels, num_samples = samples.shape + if batch_size * num_channels * num_samples == 0: + warnings.warn( + "An empty samples tensor was passed to {}".format(self.__class__.__name__) + ) + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) + if is_multichannel(samples): if num_channels > num_samples: warnings.warn( @@ -133,6 +143,7 @@ def forward( " other words, the shape must be (batch size, channels, samples), not" " (batch_size, samples, channels)" ) + if not self.supports_multichannel: raise MultichannelAudioNotSupportedException( "{} only supports mono audio, not multichannel audio".format( @@ -144,34 +155,54 @@ def forward( if sample_rate is None and self.is_sample_rate_required(): raise RuntimeError("sample_rate is required") - if targets is None and self.is_targets_required(): + if targets is None and self.is_target_required(): raise RuntimeError("targets is required") - if targets is not None: + has_targets = targets is not None + + if has_targets and not self.supports_target: + warnings.warn(f"Targets are not (yet) supported by {self.__class__.__name__}") + + if has_targets: + + if not isinstance(targets, Tensor) or len(targets.shape) != 4: - if len(targets.shape) != 4: raise RuntimeError( - "torch-audiomentations expects target tensors to be four-dimensional, with" + "torch-audiomentations expects four-dimensional target tensors, with" " dimension ordering like [batch_size, num_channels, num_frames, num_classes]." " If your target is binary, you can use a shape like [batch_size, num_channels, num_frames, 1]." " If your target is for the whole channel, you can use a shape like [batch_size, num_channels, 1, num_classes]." ) - batch_size_, num_channels_, num_frames, num_classes = targets.shape + ( + target_batch_size, + target_num_channels, + num_frames, + num_classes, + ) = targets.shape - if batch_size_ != batch_size: + if target_batch_size != batch_size: raise RuntimeError( - f"samples ({batch_size}) and target ({batch_size_}) batch sizes must be equal." + f"samples ({batch_size}) and target ({target_batch_size}) batch sizes must be equal." ) - if num_channels != num_channels_: + if num_channels != target_num_channels: raise RuntimeError( - f"samples ({num_channels}) and target ({num_channels_}) number of channels must be equal." + f"samples ({num_channels}) and target ({target_num_channels}) number of channels must be equal." ) target_rate = target_rate or self.target_rate - if target_rate is None and self.is_target_rate_required(): - # IDEA: automatically estimate target_rate based on samples, sample_rate, and targets - raise RuntimeError("target_rate is required") + if target_rate is None: + if num_frames > 1: + target_rate = round(sample_rate * num_frames / num_samples) + warnings.warn( + f"target_rate is required by {self.__class__.__name__}. " + f"It has been automatically inferred from targets shape to {target_rate}. " + f"If this is incorrect, you can pass it directly." + ) + else: + # corner case where num_frames == 1, meaning that the target is for the whole sample, + # not frame-based. we arbitrarily set target_rate to 0. 
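+                    # (transforms such as Shift check for target_rate == 0 and then
+                    # leave targets untouched)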
+ target_rate = 0 if not self.are_parameters_frozen: @@ -197,11 +228,11 @@ def forward( cloned_samples = samples.clone() - if targets is None: + if has_targets: + cloned_targets = targets.clone() + else: cloned_targets = None selected_targets = None - else: - cloned_targets = targets.clone() if self.p_mode == "per_channel": @@ -212,7 +243,7 @@ def forward( self.transform_parameters["should_apply"] ] - if targets is not None: + if has_targets: cloned_targets = cloned_targets.reshape( batch_size * num_channels, 1, num_frames, num_classes ) @@ -222,14 +253,14 @@ def forward( if not self.are_parameters_frozen: self.randomize_parameters( - selected_samples, + samples=selected_samples, sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( - selected_samples, + perturbed: ObjectDict = self.apply_transform( + samples=selected_samples, sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, @@ -237,22 +268,25 @@ def forward( cloned_samples[ self.transform_parameters["should_apply"] - ] = perturbed_samples + ] = perturbed.samples cloned_samples = cloned_samples.reshape( batch_size, num_channels, num_samples ) - if targets is None: - return cloned_samples - - else: + if has_targets: cloned_targets[ self.transform_parameters["should_apply"] - ] = perturbed_targets + ] = perturbed.targets cloned_targets = cloned_targets.reshape( batch_size, num_channels, num_frames, num_classes ) - return cloned_samples, cloned_targets + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) elif self.p_mode == "per_example": @@ -260,7 +294,7 @@ def forward( self.transform_parameters["should_apply"] ] - if targets is not None: + if has_targets: selected_targets = cloned_targets[ self.transform_parameters["should_apply"] ] @@ -269,14 +303,14 @@ def forward( if not self.are_parameters_frozen: self.randomize_parameters( - selected_samples, - sample_rate, + samples=selected_samples, + sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( - selected_samples, + perturbed: ObjectDict = self.apply_transform( + samples=selected_samples, sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, @@ -284,59 +318,83 @@ def forward( cloned_samples[ self.transform_parameters["should_apply"] - ] = perturbed_samples + ] = perturbed.samples - if targets is None: - return cloned_samples - - else: + if has_targets: cloned_targets[ self.transform_parameters["should_apply"] - ] = perturbed_targets - return cloned_samples, cloned_targets + ] = perturbed.targets + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) elif self.mode == "per_channel": - b, c, s = selected_samples.shape + ( + selected_batch_size, + selected_num_channels, + selected_num_samples, + ) = selected_samples.shape - selected_samples = selected_samples.reshape(b * c, 1, s) + assert selected_num_samples == num_samples + + selected_samples = selected_samples.reshape( + selected_batch_size * selected_num_channels, + 1, + selected_num_samples, + ) - if targets is not None: + if has_targets: selected_targets = selected_targets.reshape( - b * c, 1, num_frames, num_classes + selected_batch_size * selected_num_channels, + 1, + num_frames, + num_classes, ) if not 
self.are_parameters_frozen: self.randomize_parameters( - selected_samples, - sample_rate, + samples=selected_samples, + sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( + perturbed: ObjectDict = self.apply_transform( selected_samples, sample_rate=sample_rate, targets=selected_targets, target_rate=target_rate, ) - perturbed_samples = perturbed_samples.reshape(b, c, s) + perturbed.samples = perturbed.samples.reshape( + selected_batch_size, selected_num_channels, selected_num_samples + ) cloned_samples[ self.transform_parameters["should_apply"] - ] = perturbed_samples - - if targets is None: - return cloned_samples - - else: - perturbed_targets = perturbed_targets.reshape( - b, c, num_frames, num_classes + ] = perturbed.samples + + if has_targets: + perturbed.targets = perturbed.targets.reshape( + selected_batch_size, + selected_num_channels, + num_frames, + num_classes, ) cloned_targets[ self.transform_parameters["should_apply"] - ] = perturbed_targets - return cloned_samples, cloned_targets + ] = perturbed.targets + + return ObjectDict( + samples=cloned_samples, + sample_rate=perturbed.sample_rate, + targets=cloned_targets, + target_rate=perturbed.target_rate, + ) else: raise Exception("Invalid mode/p_mode combination") @@ -349,60 +407,54 @@ def forward( 1, batch_size * num_channels, num_samples ) - if targets is not None: + if has_targets: cloned_targets = cloned_targets.reshape( 1, batch_size * num_channels, num_frames, num_classes ) if not self.are_parameters_frozen: self.randomize_parameters( - cloned_samples, - sample_rate, + samples=cloned_samples, + sample_rate=sample_rate, targets=cloned_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( - cloned_samples, - sample_rate, + perturbed: ObjectDict = self.apply_transform( + samples=cloned_samples, + sample_rate=sample_rate, targets=cloned_targets, target_rate=target_rate, ) - perturbed_samples = perturbed_samples.reshape( + perturbed.samples = perturbed.samples.reshape( batch_size, num_channels, num_samples ) - if targets is None: - return perturbed_samples - - else: - perturbed_targets = perturbed_targets.reshape( + if has_targets: + perturbed.targets = perturbed.targets.reshape( batch_size, num_channels, num_frames, num_classes ) - return perturbed_samples, perturbed_targets + + return perturbed elif self.mode == "per_example": if not self.are_parameters_frozen: self.randomize_parameters( - cloned_samples, - sample_rate, + samples=cloned_samples, + sample_rate=sample_rate, targets=cloned_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( - cloned_samples, - sample_rate, + perturbed = self.apply_transform( + samples=cloned_samples, + sample_rate=sample_rate, targets=cloned_targets, target_rate=target_rate, ) - if targets is None: - return perturbed_samples - - else: - return perturbed_samples, perturbed_targets + return perturbed elif self.mode == "per_channel": @@ -410,48 +462,49 @@ def forward( batch_size * num_channels, 1, num_samples ) - if targets is not None: + if has_targets: cloned_targets = cloned_targets.reshape( batch_size * num_channels, 1, num_frames, num_classes ) if not self.are_parameters_frozen: self.randomize_parameters( - cloned_samples, - sample_rate, + samples=cloned_samples, + sample_rate=sample_rate, targets=cloned_targets, target_rate=target_rate, ) - perturbed_samples, perturbed_targets = self.apply_transform( + perturbed: 
ObjectDict = self.apply_transform( cloned_samples, sample_rate, targets=cloned_targets, target_rate=target_rate, ) - perturbed_samples = perturbed_samples.reshape( + perturbed.samples = perturbed.samples.reshape( batch_size, num_channels, num_samples ) - if targets is None: - return perturbed_samples - - else: - perturbed_targets = perturbed_targets.reshape( + if has_targets: + perturbed.targets = perturbed.targets.reshape( batch_size, num_channels, num_frames, num_classes ) - return perturbed_samples, perturbed_targets + return perturbed + else: raise Exception("Invalid mode") + else: raise Exception("Invalid p_mode {}".format(self.p_mode)) - if targets is None: - return samples - else: - return samples, targets + return ObjectDict( + samples=samples, + sample_rate=sample_rate, + targets=targets, + target_rate=target_rate, + ) def _forward_unimplemented(self, *inputs) -> None: # Avoid IDE error message like "Class ... must implement all abstract methods" @@ -460,20 +513,20 @@ def _forward_unimplemented(self, *inputs) -> None: def randomize_parameters( self, - selected_samples, - sample_rate: typing.Optional[int] = None, - targets=None, - target_rate: typing.Optional[int] = None, + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, ): pass def apply_transform( self, - selected_samples, - sample_rate: typing.Optional[int] = None, - targets=None, - target_rate: typing.Optional[int] = None, - ): + samples: Tensor = None, + sample_rate: Optional[int] = None, + targets: Optional[Tensor] = None, + target_rate: Optional[int] = None, + ) -> ObjectDict: raise NotImplementedError() @@ -499,8 +552,5 @@ def unfreeze_parameters(self): def is_sample_rate_required(self) -> bool: return self.requires_sample_rate - def is_targets_required(self) -> bool: - return self.requires_targets - - def is_target_rate_required(self) -> bool: - return self.requires_target_rate + def is_target_required(self) -> bool: + return self.requires_target From 21e54e8036b111bad4a94c1168ddedffa41cefe1 Mon Sep 17 00:00:00 2001 From: iver56 Date: Wed, 30 Mar 2022 14:53:22 +0200 Subject: [PATCH 12/15] Remove try-except for ObjectDict compatibility with old python versions --- torch_audiomentations/utils/object_dict.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torch_audiomentations/utils/object_dict.py b/torch_audiomentations/utils/object_dict.py index 3af0754f..f59ad186 100644 --- a/torch_audiomentations/utils/object_dict.py +++ b/torch_audiomentations/utils/object_dict.py @@ -1,13 +1,9 @@ # Inspired by tornado # https://www.tornadoweb.org/en/stable/_modules/tornado/util.html#ObjectDict -try: - import typing - from typing import cast +import typing - _ObjectDictBase = typing.Dict[str, typing.Any] -except ImportError: - _ObjectDictBase = dict +_ObjectDictBase = typing.Dict[str, typing.Any] class ObjectDict(_ObjectDictBase): From fc37f041132ddfe7addd7afddcd3e245f14265b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 31 Mar 2022 16:37:00 +0200 Subject: [PATCH 13/15] feat: add test_varying_snr_within_batch test for Mix augmentation --- tests/test_mix.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/test_mix.py diff --git a/tests/test_mix.py b/tests/test_mix.py new file mode 100644 index 00000000..846e6edc --- /dev/null +++ b/tests/test_mix.py @@ -0,0 +1,65 @@ +import os +import random +import shutil +import tempfile +import unittest 
+import uuid +from pathlib import Path + +import numpy as np +import pytest +import torch +from scipy.io.wavfile import write + +from torch_audiomentations import Mix +from torch_audiomentations.utils.dsp import calculate_rms +from torch_audiomentations.utils.file import load_audio +from .utils import TEST_FIXTURES_DIR + + +class TestMix(unittest.TestCase): + def setUp(self): + self.sample_rate = 16000 + self.guitar = ( + torch.from_numpy( + load_audio( + TEST_FIXTURES_DIR / "acoustic_guitar_0.wav", + sample_rate=self.sample_rate, + ) + ) + .unsqueeze(0) + .unsqueeze(0) + ) + self.noise = ( + torch.from_numpy( + load_audio( + TEST_FIXTURES_DIR / "bg" / "bg.wav", sample_rate=self.sample_rate, + ) + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + common_num_samples = min(self.guitar.shape[-1], self.noise.shape[-1]) + self.guitar = self.guitar[:, :, :common_num_samples] + self.noise = self.noise[:, :, :common_num_samples] + self.input_audios = torch.cat([self.guitar, self.noise], dim=0) + + def test_varying_snr_within_batch(self): + min_snr_in_db = 3 + max_snr_in_db = 30 + augment = Mix(min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0) + mixed_audios = augment(self.input_audios, self.sample_rate).samples + + self.assertEqual(tuple(mixed_audios.shape), tuple(self.input_audios.shape)) + self.assertFalse(torch.equal(mixed_audios, self.input_audios)) + + added_audios = mixed_audios - self.input_audios + + for i in range(len(self.input_audios)): + signal_rms = calculate_rms(self.input_audios[i]) + added_rms = calculate_rms(added_audios[i]) + snr_in_db = 20 * torch.log10(signal_rms / added_rms).item() + self.assertGreaterEqual(snr_in_db, min_snr_in_db) + self.assertLessEqual(snr_in_db, max_snr_in_db) + From 932e5623b8af7782a7741bb8ce5aab965b7316ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 31 Mar 2022 16:37:43 +0200 Subject: [PATCH 14/15] fix: remove useless imports --- tests/test_mix.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_mix.py b/tests/test_mix.py index 846e6edc..02fa92a1 100644 --- a/tests/test_mix.py +++ b/tests/test_mix.py @@ -1,15 +1,6 @@ -import os -import random -import shutil -import tempfile import unittest -import uuid -from pathlib import Path -import numpy as np -import pytest import torch -from scipy.io.wavfile import write from torch_audiomentations import Mix from torch_audiomentations.utils.dsp import calculate_rms From d25249795ff2fb68ab9ff5de774b6cabe8c02ab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 31 Mar 2022 16:57:02 +0200 Subject: [PATCH 15/15] feat: add targets-related tests in Mix --- tests/test_mix.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_mix.py b/tests/test_mix.py index 02fa92a1..da95a7f3 100644 --- a/tests/test_mix.py +++ b/tests/test_mix.py @@ -33,8 +33,20 @@ def setUp(self): common_num_samples = min(self.guitar.shape[-1], self.noise.shape[-1]) self.guitar = self.guitar[:, :, :common_num_samples] + + self.guitar_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.guitar_target[:, :, :, 0] = 1 + self.noise = self.noise[:, :, :common_num_samples] + self.noise_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.noise_target[:, :, :, 1] = 1 + self.input_audios = torch.cat([self.guitar, self.noise], dim=0) + self.input_targets = torch.cat([self.guitar_target, self.noise_target], dim=0) def test_varying_snr_within_batch(self): 
min_snr_in_db = 3 @@ -54,3 +66,32 @@ def test_varying_snr_within_batch(self): self.assertGreaterEqual(snr_in_db, min_snr_in_db) self.assertLessEqual(snr_in_db, max_snr_in_db) + def test_targets_union(self): + augment = Mix(p=1.0, mix_target="union") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + # check guitar target is still active in first (guitar) sample + self.assertTrue( + torch.equal(mixed_targets[0, :, :, 0], self.input_targets[0, :, :, 0]) + ) + # check noise target is still active in second (noise) sample + self.assertTrue( + torch.equal(mixed_targets[1, :, :, 1], self.input_targets[1, :, :, 1]) + ) + + def test_targets_original(self): + augment = Mix(p=1.0, mix_target="original") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + self.assertTrue(torch.equal(mixed_targets, self.input_targets)) +
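
A minimal end-to-end sketch of the API this patch series converges on. The batch
shapes and parameter values below are illustrative only; `Mix`, `ObjectDict`, and
the `samples` / `sample_rate` / `targets` / `target_rate` keywords are taken from
the diffs above.

    import torch

    from torch_audiomentations import Mix

    # Hypothetical batch: 2 examples, 1 channel, 1 second of 16 kHz audio,
    # with frame-based targets (100 frames, 2 classes) for each example.
    samples = torch.randn(2, 1, 16000)
    targets = torch.zeros(2, 1, 100, 2)

    augment = Mix(min_snr_in_db=3.0, max_snr_in_db=30.0, p=1.0, mix_target="union")

    # Transforms now return an ObjectDict rather than a bare tensor or a tuple.
    output = augment(
        samples=samples, sample_rate=16000, targets=targets, target_rate=100
    )
    mixed_samples = output.samples  # (2, 1, 16000)
    mixed_targets = output.targets  # (2, 1, 100, 2)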