silero: fix vad padding & choppy audio (#631)
theomonnom authored Aug 16, 2024
1 parent 6f534ae commit 7b611cd
Showing 7 changed files with 206 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .changeset/tidy-hairs-poke.md
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-silero": patch
---

silero: fix vad padding & static audio
6 changes: 5 additions & 1 deletion livekit-agents/livekit/agents/vad.py
@@ -27,7 +27,11 @@ class VADEvent:
silence_duration: float
"""duration of the silence in seconds"""
frames: List[rtc.AudioFrame] = field(default_factory=list)
"""list of audio frames of the speech"""
"""list of audio frames of the speech
start_of_speech: contains the complete audio chunks that triggered the detection)
end_of_speech: contains the complete user speech
"""
probability: float = 0.0
"""smoothed probability of the speech (only for INFERENCE_DONE event)"""
inference_duration: float = 0.0
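For orientation, here is a minimal consumer sketch of how the two payloads described above differ in practice. It is a sketch only: it assumes a stream obtained from silero.VAD.load(...).stream() and the VADStream/VADEventType names used in the tests added later in this commit.

from livekit.agents import vad

async def consume(stream: vad.VADStream) -> None:
    async for ev in stream:
        if ev.type == vad.VADEventType.START_OF_SPEECH:
            # ev.frames: only the padded chunks that triggered detection
            print("speech started, buffered:", ev.speech_duration, "s")
        elif ev.type == vad.VADEventType.END_OF_SPEECH:
            # ev.frames: the complete utterance, padding included
            print("speech ended after", ev.speech_duration, "s")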
102 changes: 67 additions & 35 deletions livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py
@@ -27,6 +27,8 @@
from . import onnx_model
from .log import logger

SLOW_INFERENCE_THRESHOLD = 0.2 # late by 200ms


@dataclass
class _VADOptions:
@@ -108,11 +110,14 @@ def __init__(self, opts: _VADOptions, model: onnx_model.OnnxModel) -> None:
self._task.add_done_callback(lambda _: self._executor.shutdown(wait=False))
self._exp_filter = utils.ExpFilter(alpha=0.35)

self._extra_inference_time = 0.0

@agents.utils.log_exceptions(logger=logger)
async def _main_task(self):
og_sample_rate = 0
og_needed_samples = 0 # needed samples to complete the window data
og_window_size_samples = 0 # size in samples of og_window_data
og_padding_size_samples = 0 # size in samples of padding data
og_window_data: np.ndarray | None = None

index_step = 0
@@ -143,28 +148,38 @@ async def _main_task(self):
elif og_window_data is None:
# alloc the og buffers now that we know the pushed sample rate
og_sample_rate = frame.sample_rate

og_window_size_samples = int(
(self._model.window_size_samples / self._model.sample_rate)
* og_sample_rate
)
og_padding_size_samples = int(
self._opts.padding_duration * og_sample_rate
)
og_window_data = np.empty(og_window_size_samples, dtype=np.int16)
og_needed_samples = og_window_size_samples
index_step = frame.sample_rate // 16000

speech_buffer = np.empty(
- int(self._opts.max_buffered_speech * og_sample_rate), dtype=np.int16
+ int(self._opts.max_buffered_speech * og_sample_rate)
+ + int(self._opts.padding_duration * og_sample_rate) * 2,
+ dtype=np.int16,
)
elif og_sample_rate != frame.sample_rate:
logger.error("a frame with another sample rate was already pushed")
continue

frame_data = np.frombuffer(frame.data, dtype=np.int16)
remaining_samples = len(frame_data)

while remaining_samples > 0:
to_copy = min(remaining_samples, og_needed_samples)

- index = len(og_window_data) - og_needed_samples
- og_window_data[index : index + to_copy] = frame_data[:to_copy]
+ window_index = og_window_size_samples - og_needed_samples
+ frame_index = len(frame_data) - remaining_samples
+ og_window_data[window_index : window_index + to_copy] = frame_data[
+ frame_index : frame_index + to_copy
+ ]

remaining_samples -= to_copy
og_needed_samples -= to_copy
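A quick sanity check of the rate-conversion arithmetic above, as a sketch with assumed numbers (512 samples at 16 kHz is Silero's typical window; neither value is stated in this diff):

# model window of 512 samples @ 16 kHz = 32 ms
model_window_size, model_sample_rate = 512, 16_000
og_sample_rate = 48_000  # assumed rate of the frames being pushed

# the same 32 ms expressed in input-rate samples
og_window_size_samples = int(model_window_size / model_sample_rate * og_sample_rate)
assert og_window_size_samples == 1536

# padding_duration=1.0 s reserved on each side of the speech
og_padding_size_samples = int(1.0 * og_sample_rate)
assert og_padding_size_samples == 48_000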
@@ -183,45 +198,74 @@ async def _main_task(self):
)

# run the inference
- start_time = time.time()
+ start_time = time.perf_counter()
raw_prob = await self._loop.run_in_executor(
self._executor, self._model, inference_window_data
)

inference_duration = time.perf_counter() - start_time

prob_change = abs(raw_prob - self._exp_filter.filtered())
exp = 0.5 if prob_change > 0.25 else 1
raw_prob = self._exp_filter.apply(exp=exp, sample=raw_prob)

- inference_duration = time.time() - start_time
window_duration = (
self._model.window_size_samples / self._opts.sample_rate
)
- if inference_duration > window_duration:
+ self._extra_inference_time = max(
+ 0.0,
+ self._extra_inference_time + inference_duration - window_duration,
+ )
+ if inference_duration > SLOW_INFERENCE_THRESHOLD:
logger.warning(
- "vad inference took too long - slower than realtime: %f",
- inference_duration,
+ "inference is slower than realtime",
+ extra={"delay": self._extra_inference_time},
)

pub_current_sample += og_window_size_samples

- def _copy_window():
+ def _copy_inference_window():
nonlocal speech_buffer_index
- to_copy = min(
- og_window_size_samples,
- len(speech_buffer) - speech_buffer_index,
- )
+ available_space = len(speech_buffer) - speech_buffer_index
+ to_copy = min(og_window_size_samples, available_space)
if to_copy <= 0:
- # max_buffered_speech reached
- return
+ return  # max_buffered_speech reached

speech_buffer[
speech_buffer_index : speech_buffer_index + to_copy
- ] = og_window_data
- speech_buffer_index += og_window_size_samples
+ ] = og_window_data[:to_copy]
+ speech_buffer_index += to_copy

def _reset_write_cursor():
nonlocal speech_buffer_index
if speech_buffer_index <= og_padding_size_samples:
return

padding_data = speech_buffer[
speech_buffer_index
- og_padding_size_samples : speech_buffer_index
]

speech_buffer[:og_padding_size_samples] = padding_data
speech_buffer_index = og_padding_size_samples

def _copy_speech_buffer() -> rtc.AudioFrame:
# copy the data from speech_buffer
assert speech_buffer is not None
speech_data = speech_buffer[:speech_buffer_index].tobytes()

return rtc.AudioFrame(
sample_rate=og_sample_rate,
num_channels=1,
samples_per_channel=speech_buffer_index,
data=speech_data,
)

_copy_inference_window()

if pub_speaking:
pub_speech_duration += window_duration
- _copy_window()
else:
pub_silence_duration += window_duration

@@ -242,8 +286,6 @@ def _copy_window():
silence_threshold_duration = 0.0

if not pub_speaking:
- _copy_window()

if speech_threshold_duration >= self._opts.min_speech_duration:
pub_speaking = True
pub_silence_duration = 0.0
@@ -255,6 +297,7 @@ def _copy_window():
samples_index=pub_current_sample,
silence_duration=pub_silence_duration,
speech_duration=pub_speech_duration,
frames=[_copy_speech_buffer()],
speaking=True,
)
)
@@ -263,37 +306,26 @@ def _copy_window():
speech_threshold_duration = 0.0

if not pub_speaking:
- speech_buffer_index = 0
+ _reset_write_cursor()

if (
pub_speaking
and silence_threshold_duration
- >= self._opts.min_silence_duration
+ >= self._opts.min_silence_duration + self._opts.padding_duration
):
pub_speaking = False
pub_speech_duration = 0.0
pub_silence_duration = silence_threshold_duration

- speech_data = speech_buffer[
- :speech_buffer_index
- ].tobytes()  # copy the data from speech_buffer

self._event_ch.send_nowait(
agents.vad.VADEvent(
type=agents.vad.VADEventType.END_OF_SPEECH,
samples_index=pub_current_sample,
silence_duration=pub_silence_duration,
speech_duration=pub_speech_duration,
- frames=[
- rtc.AudioFrame(
- sample_rate=og_sample_rate,
- num_channels=1,
- samples_per_channel=speech_buffer_index,
- data=speech_data,
- )
- ],
+ frames=[_copy_speech_buffer()],
speaking=False,
)
)

- speech_buffer_index = 0
+ _reset_write_cursor()
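Taken together, the allocation earlier in the diff sizes speech_buffer as max_buffered_speech plus two paddings: a leading padding preserved by _reset_write_cursor and a trailing padding captured while waiting out min_silence_duration + padding_duration. A quick check of that sizing (a sketch; 48 kHz and a 60 s cap are assumed values, not taken from the diff):

sample_rate = 48_000        # assumed input rate
max_buffered_speech = 60.0  # assumed cap, seconds
padding_duration = 1.0      # as in tests/test_vad.py below

buffer_len = int(max_buffered_speech * sample_rate) + int(padding_duration * sample_rate) * 2
assert buffer_len == 60 * 48_000 + 2 * 48_000  # speech + leading/trailing padding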
1 change: 1 addition & 0 deletions tests/.gitignore
@@ -0,0 +1 @@
**/test_vad*.wav
1 change: 1 addition & 0 deletions tests/test_ipc.py
@@ -56,6 +56,7 @@ async def _pong():
msg = await ipc.channel.arecv_message(cch, IPC_MESSAGES)
await ipc.channel.asend_message(cch, msg)
except utils.aio.duplex_unix.DuplexClosed:
print("_echo_main, duplex closed..")
break

asyncio.run(_pong())
66 changes: 66 additions & 0 deletions tests/test_vad.py
@@ -0,0 +1,66 @@
from livekit.agents import vad
from livekit.plugins import silero

from . import utils

VAD = silero.VAD.load(
min_speech_duration=0.5, min_silence_duration=0.5, padding_duration=1.0
)


async def test_chunks_vad() -> None:
frames, transcript = utils.make_test_audio(chunk_duration_ms=10)
assert len(frames) > 1, "frames aren't chunked"

stream = VAD.stream()

for frame in frames:
stream.push_frame(frame)

stream.end_input()

start_of_speech_i = 0
end_of_speech_i = 0
async for ev in stream:
if ev.type == vad.VADEventType.START_OF_SPEECH:
with open(
f"test_vad.start_of_speech_frames_{start_of_speech_i}.wav", "wb"
) as f:
f.write(utils.make_wav_file(ev.frames))

start_of_speech_i += 1

if ev.type == vad.VADEventType.END_OF_SPEECH:
with open(
f"test_vad.end_of_speech_frames_{end_of_speech_i}.wav", "wb"
) as f:
f.write(utils.make_wav_file(ev.frames))

end_of_speech_i += 1

assert start_of_speech_i > 0, "no start of speech detected"
assert start_of_speech_i == end_of_speech_i, "start and end of speech mismatch"


async def test_file_vad():
frames, transcript = utils.make_test_audio()
assert len(frames) == 1, "one frame should be the whole audio"

stream = VAD.stream()

for frame in frames:
stream.push_frame(frame)

stream.end_input()

start_of_speech_i = 0
end_of_speech_i = 0
async for ev in stream:
if ev.type == vad.VADEventType.START_OF_SPEECH:
start_of_speech_i += 1

if ev.type == vad.VADEventType.END_OF_SPEECH:
end_of_speech_i += 1

assert start_of_speech_i > 0, "no start of speech detected"
assert start_of_speech_i == end_of_speech_i, "start and end of speech mismatch"
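Since END_OF_SPEECH frames now include padding on both sides, one way to eyeball the fix is to check the duration of the files the first test writes (a sketch; assumes the test above has already produced the file):

import wave

with wave.open("test_vad.end_of_speech_frames_0.wav", "rb") as wav:
    duration = wav.getnframes() / wav.getframerate()
    # with padding_duration=1.0, expect roughly 2 s beyond the raw speech
    print(f"{duration:.2f}s")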
60 changes: 60 additions & 0 deletions tests/utils.py
@@ -1,4 +1,18 @@
from __future__ import annotations

import io
import os
import pathlib
import wave

import jiwer as tr
from livekit import rtc
from livekit.agents import utils

TEST_AUDIO_FILEPATH = os.path.join(os.path.dirname(__file__), "long.mp3")
TEST_AUDIO_TRANSCRIPT = pathlib.Path(
os.path.dirname(__file__), "long_transcript.txt"
).read_text()


def wer(hypothesis: str, reference: str) -> float:
@@ -21,3 +35,49 @@ def wer(hypothesis: str, reference: str) -> float:
reference_transform=wer_standardize_contiguous,
hypothesis_transform=wer_standardize_contiguous,
)


def read_mp3_file(path) -> rtc.AudioFrame:
mp3 = utils.codecs.Mp3StreamDecoder()
frames: list[rtc.AudioFrame] = []
with open(path, "rb") as file:
while True:
chunk = file.read(4096)
if not chunk:
break

frames.extend(mp3.decode_chunk(chunk))

return utils.merge_frames(frames) # merging just for ease of use


def make_test_audio(
chunk_duration_ms: int | None = None,
) -> tuple[list[rtc.AudioFrame], str]:
mp3_audio = read_mp3_file(TEST_AUDIO_FILEPATH)

if not chunk_duration_ms:
return [mp3_audio], TEST_AUDIO_TRANSCRIPT

chunk_size = int(mp3_audio.sample_rate / (1000 / chunk_duration_ms))
bstream = utils.audio.AudioByteStream(
sample_rate=mp3_audio.sample_rate,
num_channels=mp3_audio.num_channels,
samples_per_channel=chunk_size,
)

frames = bstream.write(mp3_audio.data.tobytes())
frames.extend(bstream.flush())
return frames, TEST_AUDIO_TRANSCRIPT
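The chunk_size arithmetic above converts a chunk duration in milliseconds to a sample count. A small worked example (48 kHz is an assumed decode rate, not stated in the diff):

sample_rate = 48_000      # assumed decode rate of long.mp3
chunk_duration_ms = 10
chunk_size = int(sample_rate / (1000 / chunk_duration_ms))
assert chunk_size == 480  # samples per 10 ms frame pushed to the VAD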


def make_wav_file(frames: list[rtc.AudioFrame]) -> bytes:
buffer = utils.merge_frames(frames)
io_buffer = io.BytesIO()
with wave.open(io_buffer, "wb") as wav:
wav.setnchannels(buffer.num_channels)
wav.setsampwidth(2) # 16-bit
wav.setframerate(buffer.sample_rate)
wav.writeframes(buffer.data)

return io_buffer.getvalue()
