Skip to content

Commit

Permalink
feat: add "output_type" argument
Browse files (browse the repository at this point in the history)
  • Loading branch information
hbredin committed Apr 4, 2022
1 parent 7b4475d commit 5491792
Show file tree
Hide file tree
Showing 37 changed files with 336 additions and 151 deletions.
33 changes: 22 additions & 11 deletions tests/test_background_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def setUp(self):

self.bg_path = TEST_FIXTURES_DIR / "bg"
self.bg_short_path = TEST_FIXTURES_DIR / "bg_short"
self.bg_noise_transform_guaranteed = AddBackgroundNoise(self.bg_path, 20, p=1.0)
self.bg_noise_transform_guaranteed = AddBackgroundNoise(
self.bg_path, 20, p=1.0, output_type="dict"
)
self.bg_short_noise_transform_guaranteed = AddBackgroundNoise(
self.bg_short_path, 20, p=1.0
self.bg_short_path, 20, p=1.0, output_type="dict"
)
self.bg_noise_transform_no_guarantee = AddBackgroundNoise(
self.bg_path, 20, p=0.0, output_type="dict"
)
self.bg_noise_transform_no_guarantee = AddBackgroundNoise(self.bg_path, 20, p=0.0)

def test_background_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.bg_noise_transform_no_guarantee(
Expand Down Expand Up @@ -118,7 +122,11 @@ def test_varying_snr_within_batch(self):
min_snr_in_db = 3
max_snr_in_db = 30
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=min_snr_in_db, max_snr_in_db=max_snr_in_db, p=1.0
self.bg_path,
min_snr_in_db=min_snr_in_db,
max_snr_in_db=max_snr_in_db,
p=1.0,
output_type="dict",
)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

Expand All @@ -142,13 +150,17 @@ def test_varying_snr_within_batch(self):
def test_invalid_params(self):
with self.assertRaises(ValueError):
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=30, max_snr_in_db=3, p=1.0
self.bg_path, min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict"
)

def test_min_equals_max(self):
desired_snr = 3.0
augment = AddBackgroundNoise(
self.bg_path, min_snr_in_db=desired_snr, max_snr_in_db=desired_snr, p=1.0
self.bg_path,
min_snr_in_db=desired_snr,
max_snr_in_db=desired_snr,
p=1.0,
output_type="dict",
)
augmented_audios = augment(self.input_audios, self.sample_rate).samples

Expand All @@ -171,11 +183,9 @@ def test_compatibility_of_resampled_length(self):
input_sample_rate = random.randint(1000, 5000)
bg_sample_rate = random.randint(1000, 5000)

noise = np.random.uniform(
low=-0.2,
high=0.2,
size=(bg_length,),
).astype(np.float32)
noise = np.random.uniform(low=-0.2, high=0.2, size=(bg_length,),).astype(
np.float32
)
tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
try:
os.makedirs(tmp_dir)
Expand All @@ -192,6 +202,7 @@ def test_compatibility_of_resampled_length(self):
max_snr_in_db=6,
p=1.0,
sample_rate=input_sample_rate,
output_type="dict",
)
transform(input_audio)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_band_pass_filter(self):
)
sample_rate = 16000

augment = BandPassFilter(p=1.0)
augment = BandPassFilter(p=1.0, output_type="dict")
for _ in range(20):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
Expand Down
2 changes: 1 addition & 1 deletion tests/test_band_stop_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_band_reject_filter(self):
)
sample_rate = 16000

augment = BandStopFilter(p=1.0)
augment = BandStopFilter(p=1.0, output_type="dict")
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).samples.numpy()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
class TestBaseClass(unittest.TestCase):
def test_parameters(self):
# Test that we can access the parameters function of nn.Module
augment = PolarityInversion(p=1.0)
augment = PolarityInversion(p=1.0, output_type="dict")
params = augment.parameters()
assert isinstance(params, types.GeneratorType)

def test_ndim_check(self):
augment = PolarityInversion(p=1.0)
augment = PolarityInversion(p=1.0, output_type="dict")
# 1D tensor not allowed
with pytest.raises(RuntimeError):
augment(torch.tensor([1.0, 0.5, 0.25, 0.125], dtype=torch.float32))
Expand Down
12 changes: 8 additions & 4 deletions tests/test_colored_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ def setUp(self):
).unsqueeze(0)

self.input_audios = torch.cat([self.input_audio] * self.batch_size, dim=0)
self.cl_noise_transform_guaranteed = AddColoredNoise(20, p=1.0)
self.cl_noise_transform_no_guarantee = AddColoredNoise(20, p=0.0)
self.cl_noise_transform_guaranteed = AddColoredNoise(
20, p=1.0, output_type="dict"
)
self.cl_noise_transform_no_guarantee = AddColoredNoise(
20, p=0.0, output_type="dict"
)

def test_colored_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.cl_noise_transform_no_guarantee(
Expand Down Expand Up @@ -76,6 +80,6 @@ def test_colored_noise_guaranteed_with_batched_tensor(self):

def test_invalid_params(self):
with self.assertRaises(ValueError):
AddColoredNoise(min_snr_in_db=30, max_snr_in_db=3, p=1.0)
AddColoredNoise(min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict")
with self.assertRaises(ValueError):
AddColoredNoise(min_f_decay=2, max_f_decay=1, p=1.0)
AddColoredNoise(min_f_decay=2, max_f_decay=1, p=1.0, output_type="dict")
23 changes: 13 additions & 10 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_compose(self):
[
Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0),
PolarityInversion(p=1.0),
]
],
output_type="dict",
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
Expand All @@ -38,7 +39,9 @@ def test_compose_with_torchaudio_transform(self):
samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32)
sample_rate = 16000

augment = Compose([Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0)])
augment = Compose(
[Vol(gain=-6, gain_type="db"), PolarityInversion(p=1.0),], output_type="dict"
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).samples.numpy()
Expand All @@ -61,6 +64,7 @@ def test_compose_with_p_zero(self):
PolarityInversion(p=1.0),
],
p=0.0,
output_type="dict",
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
Expand All @@ -75,9 +79,10 @@ def test_freeze_and_unfreeze_parameters(self):

augment = Compose(
transforms=[
Gain(min_gain_in_db=-16.000001, max_gain_in_db=-2, p=1.0),
Gain(min_gain_in_db=-16.000001, max_gain_in_db=-2, p=1.0,),
PolarityInversion(p=1.0),
]
],
output_type="dict",
)

processed_samples1 = augment(
Expand Down Expand Up @@ -106,6 +111,7 @@ def test_shuffle(self):
PeakNormalization(p=1.0),
],
shuffle=True,
output_type="dict",
)
num_peak_normalization_last = 0
num_gain_last = 0
Expand All @@ -126,14 +132,11 @@ def test_shuffle(self):
self.assertGreater(num_gain_last, 10)

def test_supported_modes_property(self):
augment = Compose(
transforms=[
PeakNormalization(p=1.0),
],
)
augment = Compose(transforms=[PeakNormalization(p=1.0),], output_type="dict")
assert augment.supported_modes == {"per_batch", "per_example", "per_channel"}

augment = Compose(
transforms=[PeakNormalization(p=1.0), ShuffleChannels(p=1.0)],
transforms=[PeakNormalization(p=1.0,), ShuffleChannels(p=1.0,),],
output_type="dict",
)
assert augment.supported_modes == {"per_example"}
26 changes: 17 additions & 9 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,34 @@
"augment",
[
# Differentiable transforms:
AddBackgroundNoise(BG_NOISE_PATH, 20, p=1.0),
ApplyImpulseResponse(IR_PATH, p=1.0),
AddBackgroundNoise(BG_NOISE_PATH, 20, p=1.0, output_type="dict"),
ApplyImpulseResponse(IR_PATH, p=1.0, output_type="dict"),
Compose(
transforms=[
Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=1.0),
PolarityInversion(p=1.0),
]
],
output_type="dict",
),
Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0),
PolarityInversion(p=1.0),
Shift(p=1.0),
Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0, output_type="dict"),
PolarityInversion(p=1.0, output_type="dict"),
Shift(p=1.0, output_type="dict"),
# Non-differentiable transforms:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
# [torch.DoubleTensor [1, 1, 5]], which is output 0 of IndexBackward, is at version 1; expected version 0 instead.
# Hint: enable anomaly detection to find the operation that failed to compute its gradient,
# with torch.autograd.set_detect_anomaly(True).
pytest.param(HighPassFilter(p=1.0), marks=pytest.mark.skip("Not differentiable")),
pytest.param(LowPassFilter(p=1.0), marks=pytest.mark.skip("Not differentiable")),
pytest.param(
PeakNormalization(p=1.0), marks=pytest.mark.skip("Not differentiable")
HighPassFilter(p=1.0, output_type="dict"),
marks=pytest.mark.skip("Not differentiable"),
),
pytest.param(
LowPassFilter(p=1.0, output_type="dict"),
marks=pytest.mark.skip("Not differentiable"),
),
pytest.param(
PeakNormalization(p=1.0, output_type="dict"),
marks=pytest.mark.skip("Not differentiable"),
),
],
)
Expand Down
Loading

0 comments on commit 5491792

Please sign in to comment.