
Commit 2ae5b6c

attempt to fix the repetition/hallucination issue identified in openai#1046 (openai#1052)

* attempt to fix the repetition/hallucination issue identified in openai#1046

* zero-pad the audio instead of spectrogram

* formatting fix

* delete debug print
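The idea behind "zero-pad the audio instead of spectrogram", as a minimal sketch (constant values taken from whisper/audio.py; the function name is illustrative, not part of the commit): padding the raw waveform before the STFT yields genuine silence frames at the end of the mel, rather than artificial all-zero columns appended to a finished spectrogram.

    import numpy as np

    SAMPLE_RATE = 16000
    CHUNK_LENGTH = 30
    N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk

    def pad_waveform(audio: np.ndarray, padding: int = N_SAMPLES) -> np.ndarray:
        # Append `padding` zero samples to the right of the waveform, so the
        # spectrogram computed from it ends in true silence frames.
        return np.pad(audio, (0, padding))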
jongwook authored and zackees committed May 5, 2023
1 parent 8ce59c5 commit 2ae5b6c
Showing 2 changed files with 38 additions and 27 deletions.
23 changes: 17 additions & 6 deletions whisper/audio.py
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import ffmpeg
 import numpy as np
@@ -15,10 +15,8 @@
 N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
 ):
     """
     Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
     n_mels: int
         The number of Mel-frequency filters, only 80 is supported
 
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
     Returns
     -------
     torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
     magnitudes = stft[..., :-1].abs() ** 2
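A usage sketch of the updated signature, assuming the constants exported by whisper/audio.py and an available CUDA device (both assumptions, not part of the diff):

    import torch
    from whisper.audio import N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram

    audio = torch.zeros(5 * SAMPLE_RATE)  # hypothetical 5-second silent clip
    # Move the waveform to the GPU, zero-pad 30 seconds on the right,
    # then compute the log-Mel spectrogram of the padded signal.
    mel = log_mel_spectrogram(audio, padding=N_SAMPLES, device="cuda")
    # Everything past this index is padding, not real content.
    content_frames = mel.shape[-1] - N_FRAMES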
42 changes: 21 additions & 21 deletions whisper/transcribe.py
@@ -11,6 +11,7 @@
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -212,14 +215,13 @@ def new_segment(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

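For reference, the arithmetic this hunk relies on (a worked example, not part of the diff): one mel frame spans HOP_LENGTH / SAMPLE_RATE = 10 ms, so N_FRAMES frames cover exactly 30 seconds, and the extra N_FRAMES of padded silence guarantees the slice in the hunk above is always full-width.

    HOP_LENGTH = 160
    SAMPLE_RATE = 16000
    N_FRAMES = 3000

    seek = 4500  # e.g. resuming 45 seconds into the audio
    time_offset = seek * HOP_LENGTH / SAMPLE_RATE  # 45.0 seconds
    # mel has content_frames + N_FRAMES columns after padding, and the loop
    # keeps seek < content_frames, so mel[:, seek : seek + N_FRAMES] always
    # yields a full N_FRAMES-wide segment, even near the end of the audio.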
@@ -246,20 +248,18 @@ def new_segment(
                 current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -278,7 +278,7 @@ def new_segment(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def new_segment(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@ def new_segment(
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
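A self-contained sketch of the refactored slicing logic above (simplified, with tokenizer details stripped out; the function name is illustrative). Hoisting single_timestamp_ending out of the old walrus expression makes the flag available even when no consecutive timestamp pair exists, which the word-timestamp seek adjustment now depends on:

    import torch

    def split_on_timestamps(tokens: torch.Tensor, timestamp_begin: int):
        # True where a token is a timestamp token.
        timestamp_tokens = tokens.ge(timestamp_begin)
        # A lone timestamp at the very end: [..., text, timestamp].
        single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

        # Positions right after each consecutive timestamp pair.
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        consecutive.add_(1)

        segments = []
        if len(consecutive) > 0:
            slices = consecutive.tolist()
            if single_timestamp_ending:
                slices.append(len(tokens))
            last_slice = 0
            for current_slice in slices:
                segments.append(tokens[last_slice:current_slice])
                last_slice = current_slice
        return segments, single_timestamp_ending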
