diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 7f3d5867..68465f93 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -194,8 +194,8 @@ def align(
             aligned_segments.append(aligned_seg)
             continue
 
-        if t1 >= MAX_DURATION or t2 - t1 < 0.02:
-            print("Failed to align segment: original start time longer than audio duration, skipping...")
+        if t1 >= MAX_DURATION:
+            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
             aligned_segments.append(aligned_seg)
             continue
 
@@ -207,17 +207,17 @@ def align(
         # TODO: Probably can get some speedup gain with batched inference here
         waveform_segment = audio[:, f1:f2]
-
+        # Handle the minimum input length for wav2vec2 models
+        if waveform_segment.shape[-1] < 400:
+            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
+            waveform_segment = torch.nn.functional.pad(
+                waveform_segment, (0, 400 - waveform_segment.shape[-1])
+            )
+        else:
+            lengths = None
+
         with torch.inference_mode():
             if model_type == "torchaudio":
-                # Handle the minimum input length for torchaudio wav2vec2 models
-                if waveform_segment.shape[-1] < 400:
-                    lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
-                    waveform_segment = torch.nn.functional.pad(
-                        waveform_segment, (0, 400 - waveform_segment.shape[-1])
-                    )
-                else:
-                    lengths = None
                 emissions, _ = model(waveform_segment.to(device), lengths=lengths)
             elif model_type == "huggingface":
                 emissions = model(waveform_segment.to(device)).logits
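
For context, here is a minimal standalone sketch of the padding logic this diff hoists above the inference block. wav2vec2-style encoders need a minimum number of samples (400, i.e. 25 ms at 16 kHz) to produce a single output frame, so shorter segments are zero-padded and the original sample count is kept in `lengths`; note the patched code only consumes `lengths` on the torchaudio path, while the padding itself now applies to both backends. The function name `pad_for_wav2vec2` and the constant `MIN_INPUT_SAMPLES` below are illustrative, not part of the patch.

```python
# Minimal sketch of the hoisted padding logic, standalone for illustration.
# Assumes a mono 16 kHz waveform tensor shaped (channels, samples); the
# 400-sample floor matches the smallest input wav2vec2 accepts (one frame).
import torch

MIN_INPUT_SAMPLES = 400  # hypothetical name for the 400-sample constant


def pad_for_wav2vec2(waveform_segment: torch.Tensor, device: str = "cpu"):
    """Zero-pad segments shorter than the wav2vec2 minimum input length.

    Returns the (possibly padded) segment and a `lengths` tensor holding
    the original sample count, or None when no padding was needed.
    """
    if waveform_segment.shape[-1] < MIN_INPUT_SAMPLES:
        # Record the true length so the model can mask the padded tail.
        lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
        # Pad only the trailing edge of the time axis with zeros.
        waveform_segment = torch.nn.functional.pad(
            waveform_segment, (0, MIN_INPUT_SAMPLES - waveform_segment.shape[-1])
        )
    else:
        lengths = None
    return waveform_segment, lengths


# Example: a 10 ms segment (160 samples) is padded out to 400 samples.
segment = torch.randn(1, 160)
padded, lengths = pad_for_wav2vec2(segment)
assert padded.shape[-1] == 400 and lengths.item() == 160
```

Moving this block out of the `model_type == "torchaudio"` branch is the substantive change: previously a sub-400-sample segment fed to a HuggingFace alignment model went in unpadded, whereas now both backends see an input of valid length.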