Skip to content

Commit

Permalink
Merge pull request #529 from MahmoudAshraf97/main
Browse files (browse the repository at this point in the history)
  • Loading branch information
m-bain authored Oct 16, 2023
2 parents a150df4 + b69956d commit 66808f6
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions whisperx/alignment.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -194,8 +194,8 @@ def align(
aligned_segments.append(aligned_seg)
continue

if t1 >= MAX_DURATION or t2 - t1 < 0.02:
print("Failed to align segment: original start time longer than audio duration, skipping...")
if t1 >= MAX_DURATION:
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
aligned_segments.append(aligned_seg)
continue

Expand All @@ -207,17 +207,17 @@ def align(

# TODO: Probably can get some speedup gain with batched inference here
waveform_segment = audio[:, f1:f2]

# Handle the minimum input length for wav2vec2 models
if waveform_segment.shape[-1] < 400:
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
waveform_segment = torch.nn.functional.pad(
waveform_segment, (0, 400 - waveform_segment.shape[-1])
)
else:
lengths = None

with torch.inference_mode():
if model_type == "torchaudio":
# Handle the minimum input length for torchaudio wav2vec2 models
if waveform_segment.shape[-1] < 400:
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
waveform_segment = torch.nn.functional.pad(
waveform_segment, (0, 400 - waveform_segment.shape[-1])
)
else:
lengths = None
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
Expand Down

0 comments on commit 66808f6

Please sign in to comment.