From d25249795ff2fb68ab9ff5de774b6cabe8c02ab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 31 Mar 2022 16:57:02 +0200 Subject: [PATCH] feat: add targets-related tests in Mix --- tests/test_mix.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_mix.py b/tests/test_mix.py index 02fa92a1..da95a7f3 100644 --- a/tests/test_mix.py +++ b/tests/test_mix.py @@ -33,8 +33,20 @@ def setUp(self): common_num_samples = min(self.guitar.shape[-1], self.noise.shape[-1]) self.guitar = self.guitar[:, :, :common_num_samples] + + self.guitar_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.guitar_target[:, :, :, 0] = 1 + self.noise = self.noise[:, :, :common_num_samples] + self.noise_target = torch.zeros( + (1, 1, common_num_samples // 7, 2), dtype=torch.int64 + ) + self.noise_target[:, :, :, 1] = 1 + self.input_audios = torch.cat([self.guitar, self.noise], dim=0) + self.input_targets = torch.cat([self.guitar_target, self.noise_target], dim=0) def test_varying_snr_within_batch(self): min_snr_in_db = 3 @@ -54,3 +66,32 @@ def test_varying_snr_within_batch(self): self.assertGreaterEqual(snr_in_db, min_snr_in_db) self.assertLessEqual(snr_in_db, max_snr_in_db) + def test_targets_union(self): + augment = Mix(p=1.0, mix_target="union") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + # check guitar target is still active in first (guitar) sample + self.assertTrue( + torch.equal(mixed_targets[0, :, :, 0], self.input_targets[0, :, :, 0]) + ) + # check noise target is still active in second (noise) sample + self.assertTrue( + torch.equal(mixed_targets[1, :, :, 1], self.input_targets[1, :, :, 1]) + ) + + def test_targets_original(self): + augment = Mix(p=1.0, mix_target="original") + mixtures = augment( + samples=self.input_audios, + sample_rate=self.sample_rate, + targets=self.input_targets, + ) + mixed_targets = mixtures.targets + + self.assertTrue(torch.equal(mixed_targets, self.input_targets)) +