Skip to content

Commit

Permalink
update test with slice
Browse files Browse the repository at this point in the history
  • Loading branch information
kamilakesbi committed Aug 19, 2024
1 parent 89b7143 commit 5b02249
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tests/models/dac/test_modeling_dac.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,11 @@ def test_integration_16khz(self):

def test_integration_24khz(self):
expected_rmse = 0.0039

expected_encoder_sums_dict = {
"loss": 30.0128,
"quantized_representation": 0.0104,
"audio_codes": 518.2788,
"projected_latents": -0.0051,

expected_encoder_output_dict = {
"quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]),
"audio_codes": torch.tensor([919, 919, 234, 777, 234]),
"projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]),
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

Expand All @@ -483,12 +482,15 @@ def test_integration_24khz(self):

with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])

expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])

# make sure audio encoded codes are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))

expected_quantized_representation = encoder_outputs['quantized_representation'][0,0,:5].cpu()
expected_audio_codes = encoder_outputs['audio_codes'][0,0,:5].cpu()
expected_projected_latents = encoder_outputs['projected_latents'][0,0,:5].cpu()

# make sure values are correct for audios slices
self.assertTrue(torch.allclose(expected_quantized_representation, expected_encoder_output_dict['quantized_representation'], atol=1e-3))
self.assertTrue(torch.allclose(expected_audio_codes, expected_encoder_output_dict['audio_codes'], atol=1e-3))
self.assertTrue(torch.allclose(expected_projected_latents, expected_encoder_output_dict['projected_latents'], atol=1e-3))

_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
Expand Down

0 comments on commit 5b02249

Please sign in to comment.