From 02c032377758244961f95f33152132cf1cf48879 Mon Sep 17 00:00:00 2001
From: Mahmoud Ashraf
Date: Sun, 15 Oct 2023 16:25:15 +0300
Subject: [PATCH 1/2] fix

---
 whisperx/alignment.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 7f3d5867..874502b8 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -194,8 +194,8 @@ def align(
             aligned_segments.append(aligned_seg)
             continue
 
-        if t1 >= MAX_DURATION or t2 - t1 < 0.02:
-            print("Failed to align segment: original start time longer than audio duration, skipping...")
+        if t1 >= MAX_DURATION:
+            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
             aligned_segments.append(aligned_seg)
             continue
 

From b69956d725ce70794a79cb32a891ba1f2128f6db Mon Sep 17 00:00:00 2001
From: Mahmoud Ashraf
Date: Mon, 16 Oct 2023 20:43:37 +0300
Subject: [PATCH 2/2] .

---
 whisperx/alignment.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 874502b8..68465f93 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -207,17 +207,17 @@ def align(
 
         # TODO: Probably can get some speedup gain with batched inference here
         waveform_segment = audio[:, f1:f2]
-
+        # Handle the minimum input length for wav2vec2 models
+        if waveform_segment.shape[-1] < 400:
+            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
+            waveform_segment = torch.nn.functional.pad(
+                waveform_segment, (0, 400 - waveform_segment.shape[-1])
+            )
+        else:
+            lengths = None
+
         with torch.inference_mode():
             if model_type == "torchaudio":
-                # Handle the minimum input length for torchaudio wav2vec2 models
-                if waveform_segment.shape[-1] < 400:
-                    lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
-                    waveform_segment = torch.nn.functional.pad(
-                        waveform_segment, (0, 400 - waveform_segment.shape[-1])
-                    )
-                else:
-                    lengths = None
                 emissions, _ = model(waveform_segment.to(device), lengths=lengths)
             elif model_type == "huggingface":
                 emissions = model(waveform_segment.to(device)).logits
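
For context, a minimal standalone sketch of the padding logic that PATCH 2/2 hoists out of the torchaudio-only branch: wav2vec2 alignment models need at least 400 input samples, so shorter segments are zero-padded and their true length is passed separately. The helper name pad_to_min_length and the example tensor are illustrative, not part of whisperx; only torch is assumed.

import torch

def pad_to_min_length(waveform_segment: torch.Tensor, device: str = "cpu"):
    # Mirror of the patched logic: pad segments shorter than 400 samples
    # and record the original length so the model can ignore the padded tail.
    if waveform_segment.shape[-1] < 400:
        lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
        waveform_segment = torch.nn.functional.pad(
            waveform_segment, (0, 400 - waveform_segment.shape[-1])
        )
    else:
        lengths = None
    return waveform_segment, lengths

# Example: a 150-sample mono segment becomes shape (1, 400), lengths == tensor([150]).
segment = torch.randn(1, 150)
padded, lengths = pad_to_min_length(segment)
print(padded.shape, lengths)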