diff --git a/Dockerfile b/Dockerfile index 895088b..6cc1ff3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,10 @@ COPY setup.py /usr/src/app/setup.py COPY whisper_timestamped /usr/src/app/whisper_timestamped # Install -RUN cd /usr/src/app/ && pip3 install ".[dev]" && pip3 install ".[vad]" +RUN cd /usr/src/app/ && pip3 install ".[dev]" +RUN cd /usr/src/app/ && pip3 install ".[vad_silero]" +RUN cd /usr/src/app/ && pip3 install ".[vad_auditok]" +RUN cd /usr/src/app/ && pip3 install ".[test]" # Cleanup RUN rm -R /usr/src/app/requirements.txt /usr/src/app/setup.py /usr/src/app/whisper_timestamped diff --git a/README.md b/README.md index 7a29cbb..087ffd8 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,9 @@ Multilingual Automatic Speech Recognition with word-level timestamps and confide * [Plotting word alignment](#plotting-word-alignment) * [Example output](#example-output) * [Options that may improve results](#options-that-may-improve-results) + * [Accurate Whisper transcription](#accurate-whisper-transcription) + * [Running Voice Activity Detection (VAD) before sending to Whisper](#running-voice-activity-detection-vad-before-sending-to-whisper) + * [Detecting disfluencies](#detecting-disfluencies) * [Acknowlegment](#acknowlegment) * [Citations](#citations) @@ -32,7 +35,9 @@ The approach is based on Dynamic Time Warping (DTW) applied to cross-attention w `whisper-timestamped` is an extension of the [`openai-whisper`](https://pypi.org/project/whisper-openai/) Python package and is meant to be compatible with any version of `openai-whisper`. It provides more efficient/accurate word timestamps, along with those additional features: -* Voice Activity Detection (VAD) can be run before applying Whisper model, to avoid hallucinations due to errors in the training data (for instance, predicting "Thanks you for watching!" on pure silence). 
+* Voice Activity Detection (VAD) can be run before applying Whisper model, + to avoid hallucinations due to errors in the training data (for instance, predicting "Thanks you for watching!" on pure silence). + Several VAD methods are available: silero (default), silero:v3.1, auditok * When the language is not specified, the language probabilities are provided among the outputs. ### Notes on other approaches @@ -55,7 +60,7 @@ Requirements: You can install `whisper-timestamped` either by using pip: ```bash -pip3 install git+https://github.com/linto-ai/whisper-timestamped +pip3 install whisper-timestamped ``` or by cloning this repository and running installation: @@ -327,6 +332,27 @@ results = whisper_timestamped.transcribe(model, audio, vad=True, ...) whisper_timestamped --vad True ... ``` +By default, the VAD method used is [silero](https://github.com/snakers4/silero-vad). +But other methods are available, such as earlier versions of silero, or [auditok](https://github.com/amsehili/auditok). +Those methods were introduced because the latest versions of silero VAD can have a lot of false alarms on some audios (speech detected on silence). +* In Python: +```python +results = whisper_timestamped.transcribe(model, audio, vad="silero:v3.1", ...) +results = whisper_timestamped.transcribe(model, audio, vad="auditok", ...) +``` +* On the command line: +```bash +whisper_timestamped --vad silero:v3.1 ... +whisper_timestamped --vad auditok ... +``` + +To view the VAD results, you can use the `--plot` option of the `whisper_timestamped` CLI, +or the `plot_word_alignment` option of the `whisper_timestamped.transcribe()` Python function. 
+It will show the VAD results on the input audio signal as follows (x-axis is time in seconds): +| **vad="silero:v4.0"** | **vad="silero:v3.1"** | **vad="auditok"** | +| :---: | :---: | :---: | +| ![Example VAD](figs/VAD_silero_v4.0.png) | ![Example VAD](figs/VAD_silero_v3.1.png) | ![Example VAD](figs/VAD_auditok.png) | + #### Detecting disfluencies Whisper models tend to remove speech disfluencies (filler words, hesitations, repetitions, etc.). Without precautions, the disfluencies that are not transcribed will affect the timestamp of the following word: the timestamp of the beginning of the word will actually be the timestamp of the beginning of the disfluencies. `whisper-timestamped` can have some heuristics to avoid this. diff --git a/figs/VAD_auditok.png b/figs/VAD_auditok.png new file mode 100644 index 0000000..9dbaa4a Binary files /dev/null and b/figs/VAD_auditok.png differ diff --git a/figs/VAD_silero_v3.1.png b/figs/VAD_silero_v3.1.png new file mode 100644 index 0000000..5216ef1 Binary files /dev/null and b/figs/VAD_silero_v3.1.png differ diff --git a/figs/VAD_silero_v4.0.png b/figs/VAD_silero_v4.0.png new file mode 100644 index 0000000..3b516f9 Binary files /dev/null and b/figs/VAD_silero_v4.0.png differ diff --git a/setup.py b/setup.py index 96758e3..44db60e 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ name="whisper-timestamped", py_modules=["whisper_timestamped"], version=version, - description="Add to OpenAI Whisper the capability to give word timestamps", + description="OpenAI Whisper ASR with accurate word timestamps, language detection confidence, several options of VAD, and more.", python_requires=">=3.7", author="Jeronymous", url="https://github.com/linto-ai/whisper-timestamped", @@ -37,7 +37,9 @@ }, include_package_data=True, extras_require={ - 'dev': ['matplotlib', 'jsonschema', 'transformers'], - 'vad': ['onnxruntime', 'torchaudio'], + 'dev': ['matplotlib', 'transformers'], + 'vad_silero': ['onnxruntime', 'torchaudio'], +
'vad_auditok': ['auditok'], + 'test': ['jsonschema'], }, ) diff --git a/tests/expected/verbose/vad_auditok_words.wav.stdout b/tests/expected/verbose/vad_auditok_words.wav.stdout new file mode 100644 index 0000000..f0d2d99 --- /dev/null +++ b/tests/expected/verbose/vad_auditok_words.wav.stdout @@ -0,0 +1,8 @@ +[00:00.750 --> 00:01.470] settlement, +[00:02.950 --> 00:03.670] Kentucky, +[00:05.770 --> 00:06.290] causing +[00:07.900 --> 00:08.950] damage, +[00:10.900 --> 00:11.700] President, +[00:14.200 --> 00:14.780] expansion, +[00:17.120 --> 00:17.760] hospital, +[00:20.730 --> 00:21.330] devastated. diff --git a/tests/expected/verbose/vad_silero3.0_words.wav.stdout b/tests/expected/verbose/vad_silero3.0_words.wav.stdout new file mode 100644 index 0000000..79e9f61 --- /dev/null +++ b/tests/expected/verbose/vad_silero3.0_words.wav.stdout @@ -0,0 +1,8 @@ +[00:00.760 --> 00:01.480] settlement, +[00:02.890 --> 00:03.670] Kentucky, +[00:05.710 --> 00:06.270] causing +[00:07.850 --> 00:08.930] damage, +[00:10.940 --> 00:11.700] president, +[00:14.200 --> 00:14.780] expansion, +[00:17.120 --> 00:17.780] hospital, +[00:20.140 --> 00:21.380] devastated. diff --git a/tests/expected/verbose/vad_silero3.1_words.wav.stdout b/tests/expected/verbose/vad_silero3.1_words.wav.stdout new file mode 100644 index 0000000..54b0fd0 --- /dev/null +++ b/tests/expected/verbose/vad_silero3.1_words.wav.stdout @@ -0,0 +1,8 @@ +[00:00.760 --> 00:01.480] settlement, +[00:02.920 --> 00:03.660] Kentucky, +[00:05.760 --> 00:06.260] causing +[00:07.850 --> 00:08.940] damage, +[00:10.840 --> 00:11.700] president, +[00:14.190 --> 00:14.770] expansion, +[00:17.130 --> 00:17.750] hospital, +[00:21.200 --> 00:21.380] devastated. 
class TestTranscribeWithVad(TestHelperCli):
    """CLI tests of the ``--vad`` option, one test per VAD back-end.

    Every case runs on ``words.wav`` with the ``tiny`` model in verbose mode.
    NOTE(review): the ``prefix`` presumably selects the reference file
    ``tests/expected/verbose/<prefix>_words.wav.stdout`` -- confirm in
    ``TestHelperCli._test_cli_``.
    """

    def test_vad_default(self):
        # "--vad True" lets the library pick its default VAD method
        cli_options = [
            "--model", "tiny",
            "--accurate",
            "--language", "en",
            "--vad", "True",
            "--verbose", "True",
        ]
        self._test_cli_(cli_options, "verbose", files=["words.wav"], prefix="vad", extensions=None)

    def test_vad_custom_silero(self):
        # Pinned silero versions, checked in the same order as before (v3.1 then v3.0)
        cases = (
            ("silero:v3.1", "vad_silero3.1"),
            ("silero:v3.0", "vad_silero3.0"),
        )
        for vad_method, expected_prefix in cases:
            cli_options = [
                "--model", "tiny",
                "--accurate",
                "--language", "en",
                "--vad", vad_method,
                "--verbose", "True",
            ]
            self._test_cli_(
                cli_options,
                "verbose",
                files=["words.wav"],
                prefix=expected_prefix,
                extensions=None,
            )

    def test_vad_custom_auditok(self):
        # Unlike the silero cases, this one does not pass "--accurate"
        cli_options = [
            "--model", "tiny",
            "--language", "en",
            "--vad", "auditok",
            "--verbose", "True",
        ]
        self._test_cli_(cli_options, "verbose", files=["words.wav"], prefix="vad_auditok", extensions=None)
def check_vad_method(method, with_version=False):
    """
    Validate the `vad` option and resolve it to a VAD method.

    parameters:
        method: bool or str
            True / "True" / "true" select the default method (silero).
            False / "False" / "false" disable VAD.
            Otherwise one of "silero", "silero:<version>" (for instance "silero:v3.1"
            or "silero:3.1") or "auditok".
        with_version: bool
            if True and the method is silero, return a tuple ("silero", version) where
            version is prefixed with "v" (or is None when no version was given).
            It has no effect on the other methods.

    returns:
        False if VAD is disabled,
        ("silero", version) if with_version is True and the method is silero,
        the method string (as given, not normalized) otherwise.

    raises:
        ValueError: if the method or the silero version is not recognized.
        ImportError: if "auditok" is requested but the auditok package cannot be imported.
    """
    if method in [True, "True", "true"]:
        # Forward with_version: callers that unpack ("silero", version) must also
        # get a tuple when the default method is requested
        return check_vad_method("silero", with_version)  # default method
    if method in [False, "False", "false"]:
        return False
    if not isinstance(method, str):
        # Anything else (None, numbers, ...) cannot name a VAD method
        raise ValueError(f"Got unexpected VAD method {method}")

    if method.startswith("silero"):
        version = None
        if method != "silero":
            # Explicit check rather than assert, which is stripped with "python -O"
            if not method.startswith("silero:"):
                raise ValueError(f"Got unexpected VAD method {method}")
            version = method.split(":")[1]
            if not version.startswith("v"):
                version = "v" + version
            try:
                is_valid_version = float(version[1:]) >= 1
            except ValueError:
                is_valid_version = False
            if not is_valid_version:
                raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)")
        if with_version:
            return ("silero", version)
        return method

    if method == "auditok":
        # Fail early, with an actionable message, if the optional dependency is missing
        try:
            import auditok  # noqa: F401
        except ImportError as err:
            raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)") from err
        return method

    raise ValueError(f"Got unexpected VAD method {method}")
segment dilatation: float how much (in sec) to enlarge each speech segment detected by the VAD + method: str + VAD method to use (auditok, silero, silero:v3.1) """ - global silero_vad_model, silero_get_speech_ts - - if silero_vad_model is None: - import onnxruntime - onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." - repo_or_dir = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master") - source = "local" - if not os.path.exists(repo_or_dir): - repo_or_dir = "snakers4/silero-vad" - source = "github" - silero_vad_model, utils = torch.hub.load(repo_or_dir=repo_or_dir, model="silero_vad", onnx=True, source=source) - - silero_get_speech_ts = utils[0] - - # Cheap normalization of the volume - audio = audio / max(0.1, audio.abs().max()) - - segments = silero_get_speech_ts(audio, silero_vad_model, - min_speech_duration_ms = round(min_speech_duration * 1000), - min_silence_duration_ms = round(min_silence_duration * 1000), - return_seconds = False, - ) + global _silero_vad_model, _silero_get_speech_ts, _has_onnx + + if method.startswith("silero"): + + version = None + _, version = check_vad_method(method, True) + # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287 + need_folder_hack = version and (version < "v4") + + if _silero_vad_model is None: + # ONNX support since 3.1 in silero + if (version is None or version >= "v3.1") and (_has_onnx is not False): + onnx=True + try: + import onnxruntime + onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." 
+ _has_onnx = True + except ImportError as err: + logger.warning(f"Please install onnxruntime to use more efficiently silero VAD") + _has_onnx = False + onnx=False + else: + onnx=False + + # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 + repo_or_dir_master = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master") + repo_or_dir_specific = os.path.expanduser(f"~/.cache/torch/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master + repo_or_dir = repo_or_dir_specific + tmp_folder = None + def apply_folder_hack(): + nonlocal tmp_folder + if os.path.exists(repo_or_dir_master): + tmp_folder = repo_or_dir_master + ".tmp" + shutil.move(repo_or_dir_master, tmp_folder) + # Make a symlink to the v3.1 model, otherwise it fails + input_exists = os.path.exists(repo_or_dir_specific) + if not input_exists: + # Make dummy file for the symlink to work + os.makedirs(repo_or_dir_specific, exist_ok=True) + os.symlink(repo_or_dir_specific, repo_or_dir_master) + if not input_exists: + shutil.rmtree(repo_or_dir_specific) + + source = "local" + if not os.path.exists(repo_or_dir): + # Load specific version of silero + repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad" + source = "github" + if need_folder_hack: + apply_folder_hack() + try: + _silero_vad_model, utils = torch.hub.load(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source) + except ImportError as err: + raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err + except Exception as err: + raise RuntimeError(f"Problem when installing silero with version {version}. 
Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err + finally: + if need_folder_hack: + if os.path.exists(repo_or_dir_master): + os.remove(repo_or_dir_master) + if tmp_folder: + shutil.move(tmp_folder, repo_or_dir_master) + assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}" + + _silero_get_speech_ts = utils[0] + + # Cheap normalization of the volume + audio = audio / max(0.1, audio.abs().max()) + + segments = _silero_get_speech_ts(audio, _silero_vad_model, + min_speech_duration_ms = round(min_speech_duration * 1000), + min_silence_duration_ms = round(min_silence_duration * 1000), + return_seconds = False, + ) + + elif method == "auditok": + import auditok + + # Cheap normalization of the volume + audio = audio / max(0.1, audio.abs().max()) + + data = (audio.numpy() * 32767).astype(np.int16).tobytes() + + segments = auditok.split( + data, + sampling_rate=SAMPLE_RATE, # sampling frequency in Hz + channels=1, # number of channels + sample_width=2, # number of bytes per sample + min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds + max_dur=len(audio)/SAMPLE_RATE, # maximum duration of an event + max_silence=min_silence_duration, # maximum duration of tolerated continuous silence within an event + energy_threshold=50, + drop_trailing_silence=True, + ) + + segments = [{"start": s._meta.start * SAMPLE_RATE, "end": s._meta.end * SAMPLE_RATE} for s in segments] + + else: + raise ValueError(f"Got unexpected VAD method {method}") if dilatation > 0: dilatation = round(dilatation * SAMPLE_RATE) @@ -1847,12 +1961,28 @@ def remove_non_speech(audio, use_sample=False, min_speech_duration=0.1, min_silence_duration=1, + method="silero", plot=False, ): """ Remove non-speech segments from audio (using Silero VAD), glue the speech segments together and return the result along with a function to convert timestamps from the new audio to the original 
audio + + parameters: + audio: torch.Tensor + audio data *in 16kHz* + use_sample: bool + if True, return start and end in samples instead of seconds + min_speech_duration: float + minimum duration (in sec) of a speech segment + min_silence_duration: float + minimum duration (in sec) of a silence segment + method: str + method to use to remove non-speech segments + plot: bool or str + if True, plot the result. + If a string, save the plot to the given file """ segments = get_vad_segments( @@ -1860,6 +1990,7 @@ def remove_non_speech(audio, output_sample=True, min_speech_duration=min_speech_duration, min_silence_duration=min_silence_duration, + method=method, ) segments = [(seg["start"], seg["end"]) for seg in segments] @@ -2342,7 +2473,7 @@ def str2output_formats(string): parser.add_argument('--language', help=f"language spoken in the audio, specify None to perform language detection.", choices=sorted(whisper.tokenizer.LANGUAGES.keys()) + sorted([k.title() for k in whisper.tokenizer.TO_LANGUAGE_CODE.keys()]), default=None) # f"{', '.join(sorted(k+'('+v+')' for k,v in whisper.tokenizer.LANGUAGES.items()))} - parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations)", type=str2bool) + parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations). Can be: True, False, silero, silero:3.1 (or another version), or autitok. 
Some additional libraries might be needed") parser.add_argument('--detect_disfluencies', default=False, help="whether to try to detect disfluencies, marking them as special words [*]", type=str2bool) parser.add_argument('--recompute_all_timestamps', default=not TRUST_WHISPER_TIMESTAMP_BY_DEFAULT, help="Do not rely at all on Whisper timestamps (Experimental option: did not bring any improvement, but could be useful in cases where Whipser segment timestamp are wrong by more than 0.5 seconds)", type=str2bool) parser.add_argument("--punctuations_with_words", default=True, help="whether to include punctuations in the words", type=str2bool)