Skip to content

Commit

Permalink
version 1.14.0 : can choose among several VAD methods (different vers…
Browse files Browse the repository at this point in the history
…ions of silero, and auditok)
  • Loading branch information
Jeronymous committed Nov 30, 2023
1 parent 52100d2 commit 6744db6
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 47 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ COPY setup.py /usr/src/app/setup.py
COPY whisper_timestamped /usr/src/app/whisper_timestamped

# Install
RUN cd /usr/src/app/ && pip3 install ".[dev]" && pip3 install ".[vad]"
RUN cd /usr/src/app/ && pip3 install ".[dev]"
RUN cd /usr/src/app/ && pip3 install ".[vad_silero]"
RUN cd /usr/src/app/ && pip3 install ".[vad_auditok]"
RUN cd /usr/src/app/ && pip3 install ".[test]"

# Cleanup
RUN rm -R /usr/src/app/requirements.txt /usr/src/app/setup.py /usr/src/app/whisper_timestamped
Expand Down
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ The approach is based on Dynamic Time Warping (DTW) applied to cross-attention w

`whisper-timestamped` is an extension of the [`openai-whisper`](https://pypi.org/project/whisper-openai/) Python package and is meant to be compatible with any version of `openai-whisper`.
It provides more efficient/accurate word timestamps, along with those additional features:
* Voice Activity Detection (VAD) can be run before applying Whisper model, to avoid hallucinations due to errors in the training data (for instance, predicting "Thanks you for watching!" on pure silence).
* Voice Activity Detection (VAD) can be run before applying Whisper model,
to avoid hallucinations due to errors in the training data (for instance, predicting "Thanks you for watching!" on pure silence).
Several VAD methods are available: silero (default), auditok, auditok:v3.1
* When the language is not specified, the language probabilities are provided among the outputs.

### Notes on other approaches
Expand All @@ -55,7 +57,7 @@ Requirements:

You can install `whisper-timestamped` either by using pip:
```bash
pip3 install git+https://github.com/linto-ai/whisper-timestamped
pip3 install whisper-timestamped
```

or by cloning this repository and running installation:
Expand Down Expand Up @@ -327,6 +329,27 @@ results = whisper_timestamped.transcribe(model, audio, vad=True, ...)
whisper_timestamped --vad True ...
```

By default, the VAD method used is [silero](https://github.com/snakers4/silero-vad).
But other methods are available, such as earlier versions of silero, or [auditok](https://github.com/amsehili/auditok).
Those methods were introduced because latest versions of silero VAD can have a lot of false alarms on some audios (speech detected on silence).
* In Python:
```python
results = whisper_timestamped.transcribe(model, audio, vad="silero:v3.1", ...)
results = whisper_timestamped.transcribe(model, audio, vad="auditok", ...)
```
* On the command line:
```bash
whisper_timestamped --vad silero:v3.1 ...
whisper_timestamped --vad auditok ...
```

In order to watch the VAD results, you can use the `--plot` option of the `whisper_timestamped` CLI,
or the `plot_word_alignment` option of the `whisper_timestamped.transcribe()` Python function.
It will show the VAD results as following:
| **silero:v4.0** | **silero:v3.1** | **auditok** |
| :---: | :---: | :---: |
| ![Example VAD](figs/VAD_silero_v4.0.png) | ![Example VAD](figs/VAD_silero_v3.1.png) | ![Example VAD](figs/VAD_auditok.png) |

#### Detecting disfluencies

Whisper models tend to remove speech disfluencies (filler words, hesitations, repetitions, etc.). Without precautions, the disfluencies that are not transcribed will affect the timestamp of the following word: the timestamp of the beginning of the word will actually be the timestamp of the beginning of the disfluencies. `whisper-timestamped` can have some heuristics to avoid this.
Expand Down
Binary file added figs/VAD_auditok.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/VAD_silero_v3.1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/VAD_silero_v4.0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
name="whisper-timestamped",
py_modules=["whisper_timestamped"],
version=version,
description="Add to OpenAI Whisper the capability to give word timestamps",
description="OpenAI Whisper ASR with accurate word timestamps, language detection confidence, several options of VAD, and more.",
python_requires=">=3.7",
author="Jeronymous",
url="https://github.com/linto-ai/whisper-timestamped",
Expand All @@ -37,7 +37,9 @@
},
include_package_data=True,
extras_require={
'dev': ['matplotlib', 'jsonschema', 'transformers'],
'vad': ['onnxruntime', 'torchaudio'],
'dev': ['matplotlib', 'transformers'],
'vad_silero': ['onnxruntime', 'torchaudio'],
'vad_auditok': ['auditok'],
'test': ['jsonschema'],
},
)
187 changes: 146 additions & 41 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
__version__ = "1.13.4"
__version__ = "1.14.0"

# Set some environment variables
import os
Expand Down Expand Up @@ -130,9 +130,12 @@ def transcribe_timestamped(
Whether to compute word confidence.
If True, a finer confidence for each segment will be computed as well.
vad: bool
vad: bool or str in ["silero", "silero:3.1", "auditok"]
Whether to perform voice activity detection (VAD) on the audio file, to remove silent parts before transcribing with Whisper model.
This should decrease hallucinations from the Whisper model.
When set to True, the default VAD algorithm is used (silero).
When set to a string, the corresponding VAD algorithm is used (silero, silero:3.1 or auditok).
Note that the library for the corresponding VAD algorithm must be installed.
detect_disfluencies: bool
Whether to detect disfluencies (i.e. hesitations, filler words, repetitions, corrections, etc.) that Whisper model might have omitted in the transcription.
Expand Down Expand Up @@ -219,6 +222,7 @@ def transcribe_timestamped(
naive_approach = True

# Input options
vad = check_vad_method(vad)
if isinstance(model, str):
model = load_model(model)
if fp16 is None:
Expand Down Expand Up @@ -266,7 +270,7 @@ def transcribe_timestamped(

if vad:
audio = get_audio_tensor(audio)
audio, convert_timestamps = remove_non_speech(audio, plot=plot_word_alignment)
audio, convert_timestamps = remove_non_speech(audio, method=vad, plot=plot_word_alignment)

global num_alignment_for_plot
num_alignment_for_plot = 0
Expand Down Expand Up @@ -1773,12 +1777,43 @@ def split_tokens_on_spaces(tokens: torch.Tensor, tokenizer, remove_punctuation_f

return words, word_tokens, word_tokens_indices

silero_vad_model = None
def check_vad_method(method, with_version=False):
if method in [True, "True", "true"]:
return check_vad_method("silero") # default method
elif method in [False, "False", "false"]:
return False
elif method.startswith("silero"):
version = None
if method != "silero":
assert method.startswith("silero:"), f"Got unexpected VAD method {method}"
version = method.split(":")[1]
if not version.startswith("v"):
version = "v" + version
try:
assert float(version[1:]) >= 1
except:
raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)")
if with_version:
return ("silero", version)
else:
return method
elif method == "auditok":
try:
import auditok
except ImportError:
raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)")
else:
raise ValueError(f"Got unexpected VAD method {method}")
return method

_silero_vad_model = None
_has_onnx = None
def get_vad_segments(audio,
output_sample=False,
min_speech_duration=0.1,
min_silence_duration=0.1,
dilatation=0.5,
method="silero",
):
"""
Get speech segments from audio using Silero VAD
Expand All @@ -1793,43 +1828,96 @@ def get_vad_segments(audio,
minimum duration (in sec) of a silence segment
dilatation: float
how much (in sec) to enlarge each speech segment detected by the VAD
method: str
VAD method to use (auditok, silero, silero:v3.1)
"""
global silero_vad_model, silero_get_speech_ts

if silero_vad_model is None:
import onnxruntime
onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model."
repo_or_dir_master = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master")
repo_or_dir_v31 = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_v3.1")
repo_or_dir = repo_or_dir_v31
source = "local"
tmp_folder = None
if os.path.exists(repo_or_dir_master):
tmp_folder = repo_or_dir_master + ".tmp"
shutil.move(repo_or_dir_master, tmp_folder)
# Make a symlink to the v3.1 model, otherwise it fails
os.symlink(repo_or_dir_v31, repo_or_dir_master)
if not os.path.exists(repo_or_dir):
# Load version 3.1 from 17/12/2021 -- see https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models
# because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74
repo_or_dir = "snakers4/silero-vad:v3.1"
source = "github"
silero_vad_model, utils = torch.hub.load(repo_or_dir=repo_or_dir, model="silero_vad", onnx=True, source=source)
os.remove(repo_or_dir_master)
if tmp_folder:
shutil.move(tmp_folder, repo_or_dir_master)
assert os.path.isdir(repo_or_dir_v31)

silero_get_speech_ts = utils[0]

# Cheap normalization of the volume
audio = audio / max(0.1, audio.abs().max())

segments = silero_get_speech_ts(audio, silero_vad_model,
min_speech_duration_ms = round(min_speech_duration * 1000),
min_silence_duration_ms = round(min_silence_duration * 1000),
return_seconds = False,
)
global _silero_vad_model, _silero_get_speech_ts, _has_onnx

if method.startswith("silero"):

version = None
_, version = check_vad_method(method, True)
# See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287
need_folder_hack = version and (version < "v4")

if _silero_vad_model is None:
# ONNX support since 3.1 in silero
if (version is None or version >= "v3.1") and (_has_onnx is not False):
onnx=True
try:
import onnxruntime
onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model."
_has_onnx = True
except ImportError as err:
logger.warning(f"Please install onnxruntime to use more efficiently silero VAD")
_has_onnx = False
onnx=False
else:
onnx=False
repo_or_dir_master = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master")
repo_or_dir_specific = os.path.expanduser(f"~/.cache/torch/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master
repo_or_dir = repo_or_dir_specific
source = "local"
tmp_folder = None
if need_folder_hack:
if os.path.exists(repo_or_dir_master):
tmp_folder = repo_or_dir_master + ".tmp"
shutil.move(repo_or_dir_master, tmp_folder)
# Make a symlink to the v3.1 model, otherwise it fails
os.symlink(repo_or_dir_specific, repo_or_dir_master)
if not os.path.exists(repo_or_dir):
# Load version 3.1 from 17/12/2021 -- see https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models
# because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74
repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad"
source = "github"
try:
_silero_vad_model, utils = torch.hub.load(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source)
except ImportError as err:
raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err
except Exception as err:
raise RuntimeError(f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err
finally:
if need_folder_hack:
os.remove(repo_or_dir_master)
if tmp_folder:
shutil.move(tmp_folder, repo_or_dir_master)
assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}"

_silero_get_speech_ts = utils[0]

# Cheap normalization of the volume
audio = audio / max(0.1, audio.abs().max())

segments = _silero_get_speech_ts(audio, _silero_vad_model,
min_speech_duration_ms = round(min_speech_duration * 1000),
min_silence_duration_ms = round(min_silence_duration * 1000),
return_seconds = False,
)

elif method == "auditok":
import auditok

# Cheap normalization of the volume
audio = audio / max(0.1, audio.abs().max())

data = (audio.numpy() * 32767).astype(np.int16).tobytes()

segments = auditok.split(
data,
sampling_rate=SAMPLE_RATE, # sampling frequency in Hz
channels=1, # number of channels
sample_width=2, # number of bytes per sample
min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds
max_dur=len(audio)/SAMPLE_RATE, # maximum duration of an event
max_silence=min_silence_duration, # maximum duration of tolerated continuous silence within an event
energy_threshold=50,
drop_trailing_silence=True,
)

segments = [{"start": s._meta.start * SAMPLE_RATE, "end": s._meta.end * SAMPLE_RATE} for s in segments]

else:
raise ValueError(f"Got unexpected VAD method {method}")

if dilatation > 0:
dilatation = round(dilatation * SAMPLE_RATE)
Expand Down Expand Up @@ -1861,19 +1949,36 @@ def remove_non_speech(audio,
use_sample=False,
min_speech_duration=0.1,
min_silence_duration=1,
method="silero",
plot=False,
):
"""
Remove non-speech segments from audio (using Silero VAD),
glue the speech segments together and return the result along with
a function to convert timestamps from the new audio to the original audio
parameters:
audio: torch.Tensor
audio data *in 16kHz*
use_sample: bool
if True, return start and end in samples instead of seconds
min_speech_duration: float
minimum duration (in sec) of a speech segment
min_silence_duration: float
minimum duration (in sec) of a silence segment
method: str
method to use to remove non-speech segments
plot: bool or str
if True, plot the result.
If a string, save the plot to the given file
"""

segments = get_vad_segments(
audio,
output_sample=True,
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
method=method,
)

segments = [(seg["start"], seg["end"]) for seg in segments]
Expand Down Expand Up @@ -2356,7 +2461,7 @@ def str2output_formats(string):
parser.add_argument('--language', help=f"language spoken in the audio, specify None to perform language detection.", choices=sorted(whisper.tokenizer.LANGUAGES.keys()) + sorted([k.title() for k in whisper.tokenizer.TO_LANGUAGE_CODE.keys()]), default=None)
# f"{', '.join(sorted(k+'('+v+')' for k,v in whisper.tokenizer.LANGUAGES.items()))}

parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations)", type=str2bool)
parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations). Can be: True, False, silero, silero:3.1 (or another version), or autitok. Some additional libraries might be needed")
parser.add_argument('--detect_disfluencies', default=False, help="whether to try to detect disfluencies, marking them as special words [*]", type=str2bool)
parser.add_argument('--recompute_all_timestamps', default=not TRUST_WHISPER_TIMESTAMP_BY_DEFAULT, help="Do not rely at all on Whisper timestamps (Experimental option: did not bring any improvement, but could be useful in cases where Whipser segment timestamp are wrong by more than 0.5 seconds)", type=str2bool)
parser.add_argument("--punctuations_with_words", default=True, help="whether to include punctuations in the words", type=str2bool)
Expand Down

0 comments on commit 6744db6

Please sign in to comment.