-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #127 from shahules786/dev#97
Dev#97
- Loading branch information
Showing
2 changed files
with
125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import unittest | ||
import torch | ||
import pytest | ||
import numpy as np | ||
from torch_audiomentations.augmentations.random_crop import RandomCrop | ||
|
||
|
||
class TestRandomCrop(unittest.TestCase): | ||
def test_crop(self): | ||
samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32) | ||
sampling_rate = 16000 | ||
crop_to = 1.5 | ||
desired_samples_len = sampling_rate * crop_to | ||
Crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate) | ||
cropped_samples = Crop(samples) | ||
|
||
self.assertEqual(desired_samples_len, cropped_samples.size(-1)) | ||
|
||
def test_crop_larger_cropto(self): | ||
samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32) | ||
sampling_rate = 16000 | ||
crop_to = 3 | ||
Crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate) | ||
cropped_samples = Crop(samples) | ||
|
||
np.testing.assert_array_equal(samples, cropped_samples) | ||
self.assertEqual(samples.size(-1), cropped_samples.size(-1)) | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") | ||
def test_crop_on_device_cuda(self): | ||
|
||
samples = torch.rand( | ||
size=(8, 2, 32000), dtype=torch.float32, device=torch.device("cuda") | ||
) | ||
sampling_rate = 16000 | ||
crop_to = 1.5 | ||
desired_samples_len = sampling_rate * crop_to | ||
Crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate) | ||
cropped_samples = Crop(samples) | ||
|
||
self.assertEqual(desired_samples_len, cropped_samples.size(-1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
import typing | ||
import warnings | ||
from torch_audiomentations.utils.multichannel import is_multichannel | ||
from ..core.transforms_interface import MultichannelAudioNotSupportedException | ||
|
||
|
||
class RandomCrop(torch.nn.Module): | ||
|
||
"""Crop the audio to predefined length in max_length.""" | ||
|
||
supports_multichannel = True | ||
|
||
def __init__(self, | ||
max_length: float, | ||
sampling_rate: int, | ||
max_length_unit:str = "seconds" | ||
): | ||
""" | ||
:param max_length: length to which samples are to be cropped. | ||
:sampling_rate: sampling rate of input samples. | ||
:max_length_unit: defines the unit of max_length. | ||
"seconds": Number of seconds | ||
"samples": Number of audio samples | ||
""" | ||
super(RandomCrop, self).__init__() | ||
self.sampling_rate = sampling_rate | ||
if max_length_unit == "seconds": | ||
self.num_samples = int(self.sampling_rate * max_length) | ||
elif max_length_unit == "samples": | ||
self.num_samples = int(max_length) | ||
else: | ||
raise ValueError('max_length_unit must be "samples" or "seconds"') | ||
|
||
def forward(self, samples, sampling_rate: typing.Optional[int] = None): | ||
|
||
sample_rate = sampling_rate or self.sampling_rate | ||
if sample_rate is None: | ||
raise RuntimeError("sample_rate is required") | ||
|
||
if len(samples) == 0: | ||
warnings.warn( | ||
"An empty samples tensor was passed to {}".format(self.__class__.__name__) | ||
) | ||
return samples | ||
|
||
if len(samples.shape) != 3: | ||
raise RuntimeError( | ||
"torch-audiomentations expects input tensors to be three-dimensional, with" | ||
" dimension ordering like [batch_size, num_channels, num_samples]. If your" | ||
" audio is mono, you can use a shape like [batch_size, 1, num_samples]." | ||
) | ||
|
||
if is_multichannel(samples): | ||
if samples.shape[1] > samples.shape[2]: | ||
warnings.warn( | ||
"Multichannel audio must have channels first, not channels last. In" | ||
" other words, the shape must be (batch size, channels, samples), not" | ||
" (batch_size, samples, channels)" | ||
) | ||
if not self.supports_multichannel: | ||
raise MultichannelAudioNotSupportedException( | ||
"{} only supports mono audio, not multichannel audio".format( | ||
self.__class__.__name__ | ||
) | ||
) | ||
|
||
if samples.shape[2] < self.num_samples: | ||
warnings.warn("audio length less than cropping length") | ||
return samples | ||
|
||
start_indices = torch.randint( | ||
0, samples.shape[2] - self.num_samples, (samples.shape[2],) | ||
) | ||
samples_cropped = torch.empty( | ||
(samples.shape[0], samples.shape[1], self.num_samples) | ||
) | ||
for i, sample in enumerate(samples): | ||
|
||
samples_cropped[i] = sample.unsqueeze(0)[ | ||
:, :, start_indices[i] : start_indices[i] + self.num_samples | ||
] | ||
|
||
return samples_cropped |