Skip to content

Commit

Permalink
Add Silence Augmentation (#5476)
Browse files Browse the repository at this point in the history
* add silence augmentation

Signed-off-by: fayejf <[email protected]>

* reflect comment

Signed-off-by: fayejf <[email protected]>

* fix CI

Signed-off-by: fayejf <[email protected]>

Signed-off-by: fayejf <[email protected]>
  • Loading branch information
fayejf authored Nov 23, 2022
1 parent 27a4acc commit da953ae
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
1 change: 1 addition & 0 deletions nemo/collections/asr/parts/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Perturbation,
RirAndNoisePerturbation,
ShiftPerturbation,
SilencePerturbation,
SpeedPerturbation,
TimeStretchPerturbation,
TranscodePerturbation,
Expand Down
42 changes: 41 additions & 1 deletion nemo/collections/asr/parts/preprocessing/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import random
import subprocess
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union
from typing import Any, List, Optional, Union

import librosa
import numpy as np
Expand Down Expand Up @@ -275,6 +275,45 @@ def perturb(self, data):
data._samples = y_stretch


class SilencePerturbation(Perturbation):
    """
    Applies random silence at the start and/or end of the audio.

    On every call to `perturb`, a start and an end duration are drawn uniformly
    from the configured ranges, converted to a whole number of samples using the
    sample rate of the audio, and filled with `value`.

    Args:
        min_start_silence_secs (float): Min start silence level in secs
        max_start_silence_secs (float): Max start silence level in secs
        min_end_silence_secs (float): Min end silence level in secs
        max_end_silence_secs (float): Max end silence level in secs
        rng: Random number generator exposing `uniform(a, b)`; a new
            `random.Random()` is created when None.
        value (float): value representing silence to be added to audio array.
    """

    def __init__(
        self,
        min_start_silence_secs: float = 0,
        max_start_silence_secs: float = 0,
        min_end_silence_secs: float = 0,
        max_end_silence_secs: float = 0,
        rng: Optional[Any] = None,
        value: float = 0,
    ):
        self._min_start_silence_secs = min_start_silence_secs
        self._max_start_silence_secs = max_start_silence_secs
        self._min_end_silence_secs = min_end_silence_secs
        self._max_end_silence_secs = max_end_silence_secs

        self._rng = random.Random() if rng is None else rng
        self._value = value

    def _make_silence(self, duration_secs: float, sample_rate, samples: np.ndarray) -> np.ndarray:
        """Build a block of `value` lasting `duration_secs` that can be stacked onto `samples`."""
        num_frames = int(duration_secs * sample_rate)
        # Reuse the trailing dimensions and dtype of the audio so that
        # (a) arrays with extra axes after the first one (e.g. a channel axis,
        #     assuming time is the first axis) can be concatenated, and
        # (b) the concatenation does not promote the audio to a wider dtype
        #     (an integer `value` would otherwise yield int64 padding and a
        #     float64 result for float32 audio).
        return np.full((num_frames,) + samples.shape[1:], self._value, dtype=samples.dtype)

    def perturb(self, data):
        """
        Pad `data._samples` in place with silence before and after the audio.

        Args:
            data: audio object exposing `_samples` (array of samples) and
                `sample_rate`; `_samples` is replaced by the padded array.
        """
        samples = np.asarray(data._samples)

        # Draw the start duration before the end duration so that a seeded
        # generator produces the same sequence of paddings as before.
        start_silence_len = self._rng.uniform(self._min_start_silence_secs, self._max_start_silence_secs)
        end_silence_len = self._rng.uniform(self._min_end_silence_secs, self._max_end_silence_secs)

        start = self._make_silence(start_silence_len, data.sample_rate, samples)
        end = self._make_silence(end_silence_len, data.sample_rate, samples)

        data._samples = np.concatenate([start, samples, end], axis=0)


class GainPerturbation(Perturbation):
"""
Applies random gain to the audio.
Expand Down Expand Up @@ -779,6 +818,7 @@ def perturb(self, data):
"speed": SpeedPerturbation,
"time_stretch": TimeStretchPerturbation,
"gain": GainPerturbation,
"silence": SilencePerturbation,
"impulse": ImpulsePerturbation,
"shift": ShiftPerturbation,
"noise": NoisePerturbation,
Expand Down
26 changes: 25 additions & 1 deletion tests/collections/asr/test_preprocessing_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest
import soundfile as sf

from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation
from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, SilencePerturbation
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.asr.parts.utils.audio_utils import select_channels

Expand Down Expand Up @@ -177,3 +177,27 @@ def test_noise_perturb_channels(self, data_channels, noise_channels):
_ = perturber.perturb_with_input_noise(audio, noise)
with pytest.raises(ValueError):
_ = perturber.perturb_with_foreground_noise(audio, noise)

def test_silence_perturb(self):
    """Pad a signal loaded from a file with silence and check the resulting length.

    The min and max silence durations are set to the same value, so the
    amount of padding added on each side is known in advance.
    """
    silence_secs = 2
    silence_kwargs = {
        'min_start_silence_secs': silence_secs,
        'max_start_silence_secs': silence_secs,
        'min_end_silence_secs': silence_secs,
        'max_end_silence_secs': silence_secs,
    }

    with tempfile.TemporaryDirectory() as test_dir:
        # Write a single-channel wav file: a one-dimensional vector of random samples
        wav_path = os.path.join(test_dir, 'audio.wav')
        sf.write(wav_path, np.random.rand(self.num_samples), self.sample_rate, 'float')

        augmentor = SilencePerturbation(**silence_kwargs)

        segment = AudioSegment.from_file(wav_path)
        len_before = len(segment._samples)
        augmentor.perturb(segment)

        # Silence is added on both sides, hence twice the duration in samples
        expected_len = len_before + 2 * silence_secs * self.sample_rate
        assert len(segment._samples) == expected_len

0 comments on commit da953ae

Please sign in to comment.