Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip silence around hallucinations #1838

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions whisper/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2

# hack: truncate long words at sentence boundaries.
Expand Down
151 changes: 135 additions & 16 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import traceback
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -23,6 +23,7 @@
from .utils import (
exact_div,
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
Expand All @@ -48,6 +49,8 @@ def transcribe(
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
Expand Down Expand Up @@ -102,6 +105,14 @@ def transcribe(
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances

clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.

hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected

Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
Expand All @@ -121,6 +132,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

if decode_options.get("language", None) is None:
if not model.is_multilingual:
Expand All @@ -147,6 +159,19 @@ def transcribe(
task=task,
)

if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")

Expand Down Expand Up @@ -190,7 +215,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:

return decode_result

seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
Expand Down Expand Up @@ -229,10 +255,23 @@ def new_segment(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

Expand All @@ -257,6 +296,30 @@ def new_segment(
previous_seek = seek
current_segments = []

# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score

def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)

def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)

timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

Expand Down Expand Up @@ -330,17 +393,71 @@ def new_segment(
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift

if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)

# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size

# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue

# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]

last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end

if verbose:
for segment in current_segments:
Expand Down Expand Up @@ -427,6 +544,8 @@ def valid_model_name(name):
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on

args = parser.parse_args().__dict__
Expand Down
20 changes: 17 additions & 3 deletions whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
import zlib
from typing import Callable, Optional, TextIO
from typing import Callable, List, Optional, TextIO

system_encoding = sys.getdefaultencoding()

Expand Down Expand Up @@ -68,6 +68,20 @@ def format_timestamp(
)


def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)


def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)


class ResultWriter:
extension: str

Expand Down Expand Up @@ -129,8 +143,8 @@ def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
chunk_index = 0
words_count = max_words_per_line
Expand Down