import typing
import warnings

import torch


class RandomCrop(torch.nn.Module):
    """Crop each example in a batch to a fixed length at a random offset.

    The start position is drawn independently per batch item, so different
    items are cropped at different offsets. Input shorter than the target
    length is returned unchanged (with a warning).
    """

    supports_multichannel = True

    def __init__(
        self,
        max_length: float,
        sampling_rate: int,
        max_length_unit: str = "seconds",
    ):
        """
        :param max_length: length to which samples are to be cropped.
        :param sampling_rate: sampling rate of input samples.
        :param max_length_unit: defines the unit of max_length.
            "seconds": Number of seconds
            "samples": Number of audio samples
        :raises ValueError: if max_length_unit is not "seconds" or "samples".
        """
        super().__init__()
        self.sampling_rate = sampling_rate
        if max_length_unit == "seconds":
            self.num_samples = int(self.sampling_rate * max_length)
        elif max_length_unit == "samples":
            self.num_samples = int(max_length)
        else:
            raise ValueError('max_length_unit must be "samples" or "seconds"')

    def forward(self, samples, sampling_rate: typing.Optional[int] = None):
        """Crop ``samples`` (shape [batch, channels, time]) to ``self.num_samples``.

        The crop length was fixed in __init__, so ``sampling_rate`` is only
        validated here for API consistency with the other transforms.
        """
        sample_rate = sampling_rate or self.sampling_rate
        if sample_rate is None:
            raise RuntimeError("sample_rate is required")

        if len(samples) == 0:
            warnings.warn(
                "An empty samples tensor was passed to {}".format(self.__class__.__name__)
            )
            return samples

        if len(samples.shape) != 3:
            raise RuntimeError(
                "torch-audiomentations expects input tensors to be three-dimensional, with"
                " dimension ordering like [batch_size, num_channels, num_samples]. If your"
                " audio is mono, you can use a shape like [batch_size, 1, num_samples]."
            )

        # Inlined is_multichannel(samples) == (samples.shape[1] > 1), which
        # removes the project-internal import. The original also had an
        # unreachable "multichannel not supported" raise guarded by
        # `not self.supports_multichannel`, which is always False here.
        if samples.shape[1] > 1 and samples.shape[1] > samples.shape[2]:
            warnings.warn(
                "Multichannel audio must have channels first, not channels last. In"
                " other words, the shape must be (batch size, channels, samples), not"
                " (batch_size, samples, channels)"
            )

        if samples.shape[2] < self.num_samples:
            warnings.warn("audio length less than cropping length")
            return samples

        # Fix 1: one start index PER BATCH ITEM (shape[0]), not one per time
        # step (shape[2]) as the original accidentally requested.
        # Fix 2: randint's high bound is exclusive, so "+ 1" both permits a
        # crop ending exactly at the last sample and avoids randint(0, 0)
        # erroring when the audio length equals the crop length.
        start_indices = torch.randint(
            0,
            samples.shape[2] - self.num_samples + 1,
            (samples.shape[0],),
        )
        # Fix 3: allocate the output with the input's dtype and device;
        # the original's bare torch.empty was CPU float32, which breaks
        # CUDA (and non-float32) inputs on assignment.
        samples_cropped = torch.empty(
            (samples.shape[0], samples.shape[1], self.num_samples),
            dtype=samples.dtype,
            device=samples.device,
        )
        for i, sample in enumerate(samples):
            start = start_indices[i]
            # sample is already (channels, time); no unsqueeze needed.
            samples_cropped[i] = sample[:, start : start + self.num_samples]

        return samples_cropped