Merge pull request #127 from shahules786/dev#97
Dev#97
iver56 authored Mar 31, 2022
2 parents ab6dbcd + 0acf9f1 commit 71ba564
Showing 2 changed files with 125 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/test_random_crop.py
@@ -0,0 +1,41 @@
import unittest
import torch
import pytest
import numpy as np
from torch_audiomentations.augmentations.random_crop import RandomCrop


class TestRandomCrop(unittest.TestCase):
    def test_crop(self):
        samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32)
        sampling_rate = 16000
        crop_to = 1.5  # seconds
        desired_samples_len = int(sampling_rate * crop_to)
        crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate)
        cropped_samples = crop(samples)

        self.assertEqual(desired_samples_len, cropped_samples.size(-1))

    def test_crop_larger_cropto(self):
        samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32)
        sampling_rate = 16000
        crop_to = 3  # seconds; longer than the 2-second input
        crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate)
        cropped_samples = crop(samples)

        # When the requested crop is longer than the audio, the input is returned unchanged.
        np.testing.assert_array_equal(samples, cropped_samples)
        self.assertEqual(samples.size(-1), cropped_samples.size(-1))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
    def test_crop_on_device_cuda(self):
        samples = torch.rand(
            size=(8, 2, 32000), dtype=torch.float32, device=torch.device("cuda")
        )
        sampling_rate = 16000
        crop_to = 1.5
        desired_samples_len = int(sampling_rate * crop_to)
        crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate)
        cropped_samples = crop(samples)

        self.assertEqual(desired_samples_len, cropped_samples.size(-1))
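
The tests above exercise only the default max_length_unit="seconds". With max_length_unit="samples", max_length is interpreted directly as a sample count (see the implementation below). A minimal sketch of that code path, assuming the RandomCrop API introduced in this commit; the snippet is illustrative and not part of the diff:

import torch
from torch_audiomentations.augmentations.random_crop import RandomCrop

samples = torch.rand(size=(4, 1, 32000), dtype=torch.float32)
# Illustrative values: max_length=8000 with max_length_unit="samples" means
# "crop to 8000 samples", independent of the sampling rate.
crop = RandomCrop(max_length=8000, sampling_rate=16000, max_length_unit="samples")
assert crop(samples).size(-1) == 8000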
84 changes: 84 additions & 0 deletions torch_audiomentations/augmentations/random_crop.py
@@ -0,0 +1,84 @@
import torch
import typing
import warnings
from torch_audiomentations.utils.multichannel import is_multichannel
from ..core.transforms_interface import MultichannelAudioNotSupportedException


class RandomCrop(torch.nn.Module):
    """Crop the audio to a predefined length, given by max_length."""

    supports_multichannel = True

    def __init__(
        self,
        max_length: float,
        sampling_rate: int,
        max_length_unit: str = "seconds",
    ):
        """
        :param max_length: Length to which the samples are cropped.
        :param sampling_rate: Sampling rate of the input samples.
        :param max_length_unit: Defines the unit of max_length.
            "seconds": number of seconds
            "samples": number of audio samples
        """
        super().__init__()
        self.sampling_rate = sampling_rate
        if max_length_unit == "seconds":
            self.num_samples = int(self.sampling_rate * max_length)
        elif max_length_unit == "samples":
            self.num_samples = int(max_length)
        else:
            raise ValueError('max_length_unit must be "samples" or "seconds"')

    def forward(self, samples, sampling_rate: typing.Optional[int] = None):
        sample_rate = sampling_rate or self.sampling_rate
        if sample_rate is None:
            raise RuntimeError("sample_rate is required")

        if len(samples) == 0:
            warnings.warn(
                "An empty samples tensor was passed to {}".format(
                    self.__class__.__name__
                )
            )
            return samples

        if len(samples.shape) != 3:
            raise RuntimeError(
                "torch-audiomentations expects input tensors to be three-dimensional, with"
                " dimension ordering like [batch_size, num_channels, num_samples]. If your"
                " audio is mono, you can use a shape like [batch_size, 1, num_samples]."
            )

        if is_multichannel(samples):
            if samples.shape[1] > samples.shape[2]:
                warnings.warn(
                    "Multichannel audio must have channels first, not channels last. In"
                    " other words, the shape must be (batch size, channels, samples), not"
                    " (batch_size, samples, channels)"
                )
            if not self.supports_multichannel:
                raise MultichannelAudioNotSupportedException(
                    "{} only supports mono audio, not multichannel audio".format(
                        self.__class__.__name__
                    )
                )

        if samples.shape[2] < self.num_samples:
            warnings.warn(
                "The audio is shorter than the requested crop length. Returning it unchanged."
            )
            return samples

        # Draw an independent random start offset for each example in the batch.
        # The upper bound is inclusive of the last valid offset, so an input whose
        # length equals the crop length is handled without error.
        start_indices = torch.randint(
            0, samples.shape[2] - self.num_samples + 1, (samples.shape[0],)
        )
        # Allocate the output with the same dtype and on the same device as the
        # input, so CUDA tensors stay on the GPU.
        samples_cropped = torch.empty(
            (samples.shape[0], samples.shape[1], self.num_samples),
            dtype=samples.dtype,
            device=samples.device,
        )
        for i, sample in enumerate(samples):
            samples_cropped[i] = sample[
                :, start_indices[i] : start_indices[i] + self.num_samples
            ]

        return samples_cropped
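
For context, a minimal usage sketch of the transform added in this commit (illustrative, not part of the diff): each example in the batch is cropped at its own random offset, and inputs shorter than the target length pass through unchanged with a warning.

import torch
from torch_audiomentations.augmentations.random_crop import RandomCrop

crop = RandomCrop(max_length=1.5, sampling_rate=16000)  # 1.5 s -> 24000 samples

batch = torch.rand(size=(8, 2, 32000), dtype=torch.float32)
print(crop(batch).shape)  # torch.Size([8, 2, 24000])

short = torch.rand(size=(8, 2, 8000), dtype=torch.float32)
print(crop(short).shape)  # torch.Size([8, 2, 8000]) -- returned unchanged, with a warning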
