From f0ace909fc36ccc06bd8b253d3f8450f14814100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Fri, 27 Sep 2024 20:55:34 -0700 Subject: [PATCH] silero: support any sample rate (#805) --- .changeset/rare-cows-smile.md | 6 + .changeset/warm-needles-change.md | 5 + livekit-agents/livekit/agents/tts/tts.py | 4 +- .../livekit/agents/utils/__init__.py | 4 +- livekit-agents/livekit/agents/utils/audio.py | 147 ++++++++- livekit-agents/livekit/agents/utils/misc.py | 40 --- livekit-agents/livekit/agents/vad.py | 38 ++- .../livekit/plugins/openai/llm.py | 2 +- .../livekit/plugins/silero/vad.py | 281 ++++++++++++------ tests/test_vad.py | 9 + 10 files changed, 386 insertions(+), 150 deletions(-) create mode 100644 .changeset/rare-cows-smile.md create mode 100644 .changeset/warm-needles-change.md diff --git a/.changeset/rare-cows-smile.md b/.changeset/rare-cows-smile.md new file mode 100644 index 000000000..6e733325b --- /dev/null +++ b/.changeset/rare-cows-smile.md @@ -0,0 +1,6 @@ +--- +"livekit-agents": patch +"livekit-plugins-silero": minor +--- + +silero: support any sample rate diff --git a/.changeset/warm-needles-change.md b/.changeset/warm-needles-change.md new file mode 100644 index 000000000..501a3862a --- /dev/null +++ b/.changeset/warm-needles-change.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-silero": patch +--- + +silero: add prefix_padding_duration #801 diff --git a/livekit-agents/livekit/agents/tts/tts.py b/livekit-agents/livekit/agents/tts/tts.py index ab4d1a8a0..3faf79680 100644 --- a/livekit-agents/livekit/agents/tts/tts.py +++ b/livekit-agents/livekit/agents/tts/tts.py @@ -7,7 +7,7 @@ from livekit import rtc -from ..utils import aio, misc +from ..utils import aio, audio @dataclass @@ -71,7 +71,7 @@ async def collect(self) -> rtc.AudioFrame: frames = [] async for ev in self: frames.append(ev.frame) - return misc.merge_frames(frames) + return audio.merge_frames(frames) @abstractmethod async def _main_task(self) -> None: ... diff --git a/livekit-agents/livekit/agents/utils/__init__.py b/livekit-agents/livekit/agents/utils/__init__.py index 361308572..739f5d508 100644 --- a/livekit-agents/livekit/agents/utils/__init__.py +++ b/livekit-agents/livekit/agents/utils/__init__.py @@ -1,13 +1,15 @@ from . import aio, audio, codecs, http_context, images +from .audio import AudioBuffer, combine_frames, merge_frames from .event_emitter import EventEmitter from .exp_filter import ExpFilter from .log import log_exceptions -from .misc import AudioBuffer, merge_frames, shortuuid, time_ms +from .misc import shortuuid, time_ms from .moving_average import MovingAverage __all__ = [ "AudioBuffer", "merge_frames", + "combine_frames", "time_ms", "shortuuid", "http_context", diff --git a/livekit-agents/livekit/agents/utils/audio.py b/livekit-agents/livekit/agents/utils/audio.py index a8cbf1c17..33ab8571f 100644 --- a/livekit-agents/livekit/agents/utils/audio.py +++ b/livekit-agents/livekit/agents/utils/audio.py @@ -1,19 +1,129 @@ from __future__ import annotations import ctypes +from typing import List, Union from livekit import rtc from ..log import logger +AudioBuffer = Union[List[rtc.AudioFrame], rtc.AudioFrame] + + +def combine_frames(buffer: AudioBuffer) -> rtc.AudioFrame: + """ + Combines one or more `rtc.AudioFrame` objects into a single `rtc.AudioFrame`. + + This function concatenates the audio data from multiple frames, ensuring that + all frames have the same sample rate and number of channels. 
It efficiently + merges the data by preallocating the necessary memory and copying the frame + data without unnecessary reallocations. + + Args: + buffer (AudioBuffer): A single `rtc.AudioFrame` or a list of `rtc.AudioFrame` + objects to be combined. + + Returns: + rtc.AudioFrame: A new `rtc.AudioFrame` containing the combined audio data. + + Raises: + ValueError: If the buffer is empty. + ValueError: If frames have differing sample rates. + ValueError: If frames have differing numbers of channels. + + Example: + >>> frame1 = rtc.AudioFrame( + ... data=b"\x01\x02", sample_rate=48000, num_channels=2, samples_per_channel=1 + ... ) + >>> frame2 = rtc.AudioFrame( + ... data=b"\x03\x04", sample_rate=48000, num_channels=2, samples_per_channel=1 + ... ) + >>> combined_frame = combine_frames([frame1, frame2]) + >>> combined_frame.data + b'\x01\x02\x03\x04' + >>> combined_frame.sample_rate + 48000 + >>> combined_frame.num_channels + 2 + >>> combined_frame.samples_per_channel + 2 + """ + if not isinstance(buffer, list): + return buffer + + if not buffer: + raise ValueError("buffer is empty") + + sample_rate = buffer[0].sample_rate + num_channels = buffer[0].num_channels + + total_data_length = 0 + total_samples_per_channel = 0 + + for frame in buffer: + if frame.sample_rate != sample_rate: + raise ValueError( + f"Sample rate mismatch: expected {sample_rate}, got {frame.sample_rate}" + ) + + if frame.num_channels != num_channels: + raise ValueError( + f"Channel count mismatch: expected {num_channels}, got {frame.num_channels}" + ) + + total_data_length += len(frame.data) + total_samples_per_channel += frame.samples_per_channel + + data = bytearray(total_data_length) + offset = 0 + for frame in buffer: + frame_data = frame.data.cast("b") + data[offset : offset + len(frame_data)] = frame_data + offset += len(frame_data) + + return rtc.AudioFrame( + data=data, + sample_rate=sample_rate, + num_channels=num_channels, + samples_per_channel=total_samples_per_channel, + ) + + +merge_frames = combine_frames + class AudioByteStream: + """ + Buffer and chunk audio byte data into fixed-size frames. + + This class is designed to handle incoming audio data in bytes, + buffering it and producing audio frames of a consistent size. + It is mainly used to easily chunk big or too small audio frames + into a fixed size, helping to avoid processing very small frames + (which can be inefficient) and very large frames (which can cause + latency or processing delays). By normalizing frame sizes, it + facilitates consistent and efficient audio data processing. + """ + def __init__( self, sample_rate: int, num_channels: int, samples_per_channel: int | None = None, ) -> None: + """ + Initialize an AudioByteStream instance. + + Parameters: + sample_rate (int): The audio sample rate in Hz. + num_channels (int): The number of audio channels. + samples_per_channel (int, optional): The number of samples per channel in each frame. + If None, defaults to `sample_rate // 10` (i.e., 100ms of audio data). + + The constructor sets up the internal buffer and calculates the size of each frame in bytes. + The frame size is determined by the number of channels, samples per channel, and the size + of each sample (assumed to be 16 bits or 2 bytes). 
+ """ self._sample_rate = sample_rate self._num_channels = num_channels @@ -25,7 +135,25 @@ def __init__( ) self._buf = bytearray() - def write(self, data: bytes) -> list[rtc.AudioFrame]: + def push(self, data: bytes) -> list[rtc.AudioFrame]: + """ + Add audio data to the buffer and retrieve fixed-size frames. + + Parameters: + data (bytes): The incoming audio data to buffer. + + Returns: + list[rtc.AudioFrame]: A list of `AudioFrame` objects of fixed size. + + The method appends the incoming data to the internal buffer. + While the buffer contains enough data to form complete frames, + it extracts the data for each frame, creates an `AudioFrame` object, + and appends it to the list of frames to return. + + This allows you to feed in variable-sized chunks of audio data + (e.g., from a stream or file) and receive back a list of + fixed-size audio frames ready for processing or transmission. + """ self._buf.extend(data) frames = [] @@ -44,7 +172,24 @@ def write(self, data: bytes) -> list[rtc.AudioFrame]: return frames + write = push # Alias for the push method. + def flush(self) -> list[rtc.AudioFrame]: + """ + Flush the buffer and retrieve any remaining audio data as a frame. + + Returns: + list[rtc.AudioFrame]: A list containing any remaining `AudioFrame` objects. + + This method processes any remaining data in the buffer that does not + fill a complete frame. If the remaining data forms a partial frame + (i.e., its size is not a multiple of the expected sample size), a warning is + logged and an empty list is returned. Otherwise, it returns the final + `AudioFrame` containing the remaining data. + + Use this method when you have no more data to push and want to ensure + that all buffered audio data has been processed. + """ if len(self._buf) % (2 * self._num_channels) != 0: logger.warning("AudioByteStream: incomplete frame during flush, dropping") return [] diff --git a/livekit-agents/livekit/agents/utils/misc.py b/livekit-agents/livekit/agents/utils/misc.py index f85ae15b7..24956abbf 100644 --- a/livekit-agents/livekit/agents/utils/misc.py +++ b/livekit-agents/livekit/agents/utils/misc.py @@ -2,46 +2,6 @@ import time import uuid -from typing import List, Union - -from livekit import rtc - -AudioBuffer = Union[List[rtc.AudioFrame], rtc.AudioFrame] - - -def merge_frames(buffer: AudioBuffer) -> rtc.AudioFrame: - """ - Merges one or more AudioFrames into a single one - Args: - buffer: either a rtc.AudioFrame or a list of rtc.AudioFrame - """ - if isinstance(buffer, list): - # merge all frames into one - if len(buffer) == 0: - raise ValueError("buffer is empty") - - sample_rate = buffer[0].sample_rate - num_channels = buffer[0].num_channels - samples_per_channel = 0 - data = b"" - for frame in buffer: - if frame.sample_rate != sample_rate: - raise ValueError("sample rate mismatch") - - if frame.num_channels != num_channels: - raise ValueError("channel count mismatch") - - data += frame.data - samples_per_channel += frame.samples_per_channel - - return rtc.AudioFrame( - data=data, - sample_rate=sample_rate, - num_channels=num_channels, - samples_per_channel=samples_per_channel, - ) - - return buffer def time_ms() -> int: diff --git a/livekit-agents/livekit/agents/vad.py b/livekit-agents/livekit/agents/vad.py index ea42e9158..c3c90dd1e 100644 --- a/livekit-agents/livekit/agents/vad.py +++ b/livekit-agents/livekit/agents/vad.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -18,26 +20,42 @@ 
class VADEventType(str, Enum): @dataclass class VADEvent: + """ + Represents an event detected by the Voice Activity Detector (VAD). + """ + type: VADEventType - """type of the event""" + """Type of the VAD event (e.g., start of speech, end of speech, inference done).""" + samples_index: int - """index of the samples when the event was fired""" + """Index of the audio sample where the event occurred, relative to the inference sample rate.""" + + timestamp: float + """Timestamp (in seconds) when the event was fired.""" + speech_duration: float - """duration of the speech in seconds""" + """Duration of the detected speech segment in seconds.""" + silence_duration: float - """duration of the silence in seconds""" + """Duration of the silence segment preceding or following the speech, in seconds.""" + frames: List[rtc.AudioFrame] = field(default_factory=list) - """list of audio frames of the speech + """ + List of audio frames associated with the speech. - start_of_speech: contains the complete audio chunks that triggered the detection) - end_of_speech: contains the complete user speech + - For `start_of_speech` events, this contains the audio chunks that triggered the detection. + - For `inference_done` events, this contains the audio chunks that were processed. + - For `end_of_speech` events, this contains the complete user speech. """ + probability: float = 0.0 - """smoothed probability of the speech (only for INFERENCE_DONE event)""" + """Probability that speech is present (only for `INFERENCE_DONE` events).""" + inference_duration: float = 0.0 - """duration of the inference in seconds (only for INFERENCE_DONE event)""" + """Time taken to perform the inference, in seconds (only for `INFERENCE_DONE` events).""" + speaking: bool = False - """whether speech was detected in the frames""" + """Indicates whether speech was detected in the frames.""" @dataclass diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 42dda810e..117fabc43 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -34,8 +34,8 @@ GroqChatModels, OctoChatModels, PerplexityChatModels, + TelnyxChatModels, TogetherChatModels, - TelnyxChatModels ) from .utils import AsyncAzureADTokenProvider, build_oai_message diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index 6e0aaf78c..1c1995e94 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +from __future__ import annotations, print_function import asyncio +import math import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from typing import Literal import numpy as np import onnxruntime # type: ignore @@ -41,6 +43,12 @@ class _VADOptions: class VAD(agents.vad.VAD): + """ + Silero Voice Activity Detection (VAD) class. + + This class provides functionality to detect speech segments within audio data using the Silero VAD model. 
+ """ + @classmethod def load( cls, @@ -50,24 +58,52 @@ def load( prefix_padding_duration: float = 0.1, max_buffered_speech: float = 60.0, activation_threshold: float = 0.5, - sample_rate: int = 16000, + sample_rate: Literal[8000, 16000] = 16000, force_cpu: bool = True, # deprecated padding_duration: float | None = None, ) -> "VAD": """ - Initialize the Silero VAD. + Load and initialize the Silero VAD model. + + This method loads the ONNX model and prepares it for inference. When options are not provided, + sane defaults are used. + + **Note:** + This method is blocking and may take time to load the model into memory. + It is recommended to call this method inside your prewarm mechanism. + + **Example:** + + ```python + def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + - When options are not provided, sane defaults are used. + async def entrypoint(ctx: JobContext): + vad = (ctx.proc.userdata["vad"],) + # your agent logic... + + + if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) + ``` Args: - min_speech_duration: minimum duration of speech to start a new speech chunk - min_silence_duration: In the end of each speech, wait min_silence_duration before ending the speech - prefix_padding_duration: duration of padding to add to the beginning of each speech chunk - max_buffered_speech: maximum duration of speech to keep in the buffer (in seconds) - activation_threshold: threshold to consider a frame as speech - sample_rate: sample rate for the inference (only 8KHz and 16KHz are supported) - force_cpu: force to use CPU for inference + min_speech_duration (float): Minimum duration of speech to start a new speech chunk. + min_silence_duration (float): At the end of each speech, wait this duration before ending the speech. + prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk. + max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds). + activation_threshold (float): Threshold to consider a frame as speech. + sample_rate (Literal[8000, 16000]): Sample rate for the inference (only 8KHz and 16KHz are supported). + force_cpu (bool): Force the use of CPU for inference. + padding_duration (float | None): **Deprecated**. Use `prefix_padding_duration` instead. + + Returns: + VAD: An instance of the VAD class ready for streaming. + + Raises: + ValueError: If an unsupported sample rate is provided. """ if sample_rate not in onnx_model.SUPPORTED_SAMPLE_RATES: raise ValueError("Silero VAD only supports 8KHz and 16KHz sample rates") @@ -100,6 +136,12 @@ def __init__( self._opts = opts def stream(self) -> "VADStream": + """ + Create a new VADStream for processing audio data. + + Returns: + VADStream: A stream object for processing audio input and detecting speech. 
+ """ return VADStream( self._opts, onnx_model.OnnxModel( @@ -122,19 +164,11 @@ def __init__(self, opts: _VADOptions, model: onnx_model.OnnxModel) -> None: @agents.utils.log_exceptions(logger=logger) async def _main_task(self): - og_sample_rate = 0 - og_needed_samples = 0 # needed samples to complete the window data - og_window_size_samples = 0 # size in samples of og_window_data - og_padding_size_samples = 0 # size in samples of padding data - og_window_data: np.ndarray | None = None - - index_step = 0 - inference_window_data = np.empty( - self._model.window_size_samples, dtype=np.float32 - ) + inference_f32_data = np.empty(self._model.window_size_samples, dtype=np.float32) # a copy is exposed to the user in END_OF_SPEECH speech_buffer: np.ndarray | None = None + speech_buffer_max_reached = False speech_buffer_index: int = 0 # "pub_" means public, these values are exposed to the users through events @@ -142,87 +176,117 @@ async def _main_task(self): pub_speech_duration = 0.0 pub_silence_duration = 0.0 pub_current_sample = 0 + pub_timestamp = 0.0 + + pub_sample_rate = 0 + pub_prefix_padding_samples = 0 # size in samples of padding data speech_threshold_duration = 0.0 silence_threshold_duration = 0.0 - async for frame in self._input_ch: - if not isinstance(frame, rtc.AudioFrame): - continue # ignore flush sentinel for now + input_frames = [] + inference_frames = [] + resampler: rtc.AudioResampler | None = None - if frame.sample_rate != 8000 and frame.sample_rate % 16000 != 0: - logger.error("only 8KHz and 16KHz*X sample rates are supported") - continue + # used to avoid drift when the sample_rate ratio is not an integer + input_copy_remaining_fract = 0.0 - if og_window_data is None: - # alloc the og buffers now that we know the pushed sample rate - og_sample_rate = frame.sample_rate + async for input_frame in self._input_ch: + if not isinstance(input_frame, rtc.AudioFrame): + continue # ignore flush sentinel for now - og_window_size_samples = int( - (self._model.window_size_samples / self._model.sample_rate) - * og_sample_rate - ) - og_padding_size_samples = int( - self._opts.prefix_padding_duration * og_sample_rate + if not pub_sample_rate or speech_buffer is None: + pub_sample_rate = input_frame.sample_rate + + # alloc the buffers now that we know the input sample rate + pub_prefix_padding_samples = math.ceil( + self._opts.prefix_padding_duration * pub_sample_rate ) - og_window_data = np.empty(og_window_size_samples, dtype=np.int16) - og_needed_samples = og_window_size_samples - index_step = frame.sample_rate // 16000 speech_buffer = np.empty( - int(self._opts.max_buffered_speech * og_sample_rate) - + int(self._opts.prefix_padding_duration * og_sample_rate), + int(self._opts.max_buffered_speech * pub_sample_rate) + + int(self._opts.prefix_padding_duration * pub_sample_rate), dtype=np.int16, ) - if og_sample_rate != frame.sample_rate: + if pub_sample_rate != self._opts.sample_rate: + # resampling needed: the input sample rate isn't the same as the model's + # sample rate used for inference + resampler = rtc.AudioResampler( + input_rate=pub_sample_rate, + output_rate=self._opts.sample_rate, + quality=rtc.AudioResamplerQuality.QUICK, # VAD doesn't need high quality + ) + + elif pub_sample_rate != input_frame.sample_rate: logger.error("a frame with another sample rate was already pushed") continue - frame_data = np.frombuffer(frame.data, dtype=np.int16) - remaining_samples = len(frame_data) + input_frames.append(input_frame) + if resampler is not None: + # the resampler may have a bit of 
latency, but it is OK to ignore since it should be + # negligible + inference_frames.extend(resampler.push(input_frame)) + else: + inference_frames.append(input_frame) - while remaining_samples > 0: - to_copy = min(remaining_samples, og_needed_samples) - - window_index = og_window_size_samples - og_needed_samples - frame_index = len(frame_data) - remaining_samples - og_window_data[window_index : window_index + to_copy] = frame_data[ - frame_index : frame_index + to_copy - ] - - remaining_samples -= to_copy - og_needed_samples -= to_copy + while True: + start_time = time.perf_counter() - if og_needed_samples != 0: - continue + available_inference_samples = sum( + [frame.samples_per_channel for frame in inference_frames] + ) + if available_inference_samples < self._model.window_size_samples: + break # not enough samples to run inference - og_needed_samples = og_window_size_samples + input_frame = utils.combine_frames(input_frames) + inference_frame = utils.combine_frames(inference_frames) - # copy the data to the inference buffer by sampling at each index_step & convert to float + # convert data to f32 np.divide( - og_window_data[::index_step], + inference_frame.data[: self._model.window_size_samples], np.iinfo(np.int16).max, - out=inference_window_data, + out=inference_f32_data, dtype=np.float32, ) # run the inference - start_time = time.perf_counter() - raw_prob = await self._loop.run_in_executor( - self._executor, self._model, inference_window_data + p = await self._loop.run_in_executor( + self._executor, self._model, inference_f32_data ) - - inference_duration = time.perf_counter() - start_time - - prob_change = abs(raw_prob - self._exp_filter.filtered()) - exp = 0.5 if prob_change > 0.25 else 1 - raw_prob = self._exp_filter.apply(exp=exp, sample=raw_prob) + p = self._exp_filter.apply(exp=1.0, sample=p) window_duration = ( self._model.window_size_samples / self._opts.sample_rate ) + pub_current_sample += self._model.window_size_samples + pub_timestamp += window_duration + + resampling_ratio = pub_sample_rate / self._model.sample_rate + to_copy = ( + self._model.window_size_samples * resampling_ratio + + input_copy_remaining_fract + ) + to_copy_int = int(to_copy) + input_copy_remaining_fract = to_copy - to_copy_int + + # copy the inference window to the speech buffer + available_space = len(speech_buffer) - speech_buffer_index + to_copy_buffer = min(self._model.window_size_samples, available_space) + if to_copy_buffer > 0: + speech_buffer[ + speech_buffer_index : speech_buffer_index + to_copy_buffer + ] = input_frame.data[:to_copy_buffer] + speech_buffer_index += to_copy_buffer + elif not speech_buffer_max_reached: + # reached self._opts.max_buffered_speech (padding is included) + speech_buffer_max_reached = True + logger.warning( + "max_buffered_speech reached, ignoring further data for the current speech input" + ) + + inference_duration = time.perf_counter() - start_time self._extra_inference_time = max( 0.0, self._extra_inference_time + inference_duration - window_duration, @@ -233,33 +297,21 @@ async def _main_task(self): extra={"delay": self._extra_inference_time}, ) - pub_current_sample += og_window_size_samples - - def _copy_inference_window(): - nonlocal speech_buffer_index - - available_space = len(speech_buffer) - speech_buffer_index - to_copy = min(og_window_size_samples, available_space) - if to_copy <= 0: - return # max_buffered_speech reached - - speech_buffer[ - speech_buffer_index : speech_buffer_index + to_copy - ] = og_window_data[:to_copy] - speech_buffer_index += 
to_copy - def _reset_write_cursor(): - nonlocal speech_buffer_index - if speech_buffer_index <= og_padding_size_samples: + nonlocal speech_buffer_index, speech_buffer_max_reached + assert speech_buffer is not None + + if speech_buffer_index <= pub_prefix_padding_samples: return padding_data = speech_buffer[ speech_buffer_index - - og_padding_size_samples : speech_buffer_index + - pub_prefix_padding_samples : speech_buffer_index ] - speech_buffer[:og_padding_size_samples] = padding_data - speech_buffer_index = og_padding_size_samples + speech_buffer[:pub_prefix_padding_samples] = padding_data + speech_buffer_index = pub_prefix_padding_samples + speech_buffer_max_reached = False def _copy_speech_buffer() -> rtc.AudioFrame: # copy the data from speech_buffer @@ -267,14 +319,12 @@ def _copy_speech_buffer() -> rtc.AudioFrame: speech_data = speech_buffer[:speech_buffer_index].tobytes() return rtc.AudioFrame( - sample_rate=og_sample_rate, + sample_rate=pub_sample_rate, num_channels=1, samples_per_channel=speech_buffer_index, data=speech_data, ) - _copy_inference_window() - if pub_speaking: pub_speech_duration += window_duration else: @@ -284,15 +334,24 @@ def _copy_speech_buffer() -> rtc.AudioFrame: agents.vad.VADEvent( type=agents.vad.VADEventType.INFERENCE_DONE, samples_index=pub_current_sample, + timestamp=pub_timestamp, silence_duration=pub_silence_duration, speech_duration=pub_speech_duration, - probability=raw_prob, + probability=p, inference_duration=inference_duration, + frames=[ + rtc.AudioFrame( + data=input_frame.data[:to_copy_int].tobytes(), + sample_rate=pub_sample_rate, + num_channels=1, + samples_per_channel=to_copy_int, + ) + ], speaking=pub_speaking, ) ) - if raw_prob >= self._opts.activation_threshold: + if p >= self._opts.activation_threshold: speech_threshold_duration += window_duration silence_threshold_duration = 0.0 @@ -306,12 +365,14 @@ def _copy_speech_buffer() -> rtc.AudioFrame: agents.vad.VADEvent( type=agents.vad.VADEventType.START_OF_SPEECH, samples_index=pub_current_sample, + timestamp=pub_timestamp, silence_duration=pub_silence_duration, speech_duration=pub_speech_duration, frames=[_copy_speech_buffer()], speaking=True, ) ) + else: silence_threshold_duration += window_duration speech_threshold_duration = 0.0 @@ -332,6 +393,7 @@ def _copy_speech_buffer() -> rtc.AudioFrame: agents.vad.VADEvent( type=agents.vad.VADEventType.END_OF_SPEECH, samples_index=pub_current_sample, + timestamp=pub_timestamp, silence_duration=pub_silence_duration, speech_duration=pub_speech_duration, frames=[_copy_speech_buffer()], @@ -340,3 +402,32 @@ def _copy_speech_buffer() -> rtc.AudioFrame: ) _reset_write_cursor() + + # remove the frames that were used for inference from the input and inference frames + input_frames = [] + inference_frames = [] + + # add the remaining data + if len(input_frame.data) - to_copy_int > 0: + data = input_frame.data[to_copy_int:].tobytes() + input_frames.append( + rtc.AudioFrame( + data=data, + sample_rate=pub_sample_rate, + num_channels=1, + samples_per_channel=len(data) // 2, + ) + ) + + if len(inference_frame.data) - self._model.window_size_samples > 0: + data = inference_frame.data[ + self._model.window_size_samples : + ].tobytes() + inference_frames.append( + rtc.AudioFrame( + data=data, + sample_rate=self._opts.sample_rate, + num_channels=1, + samples_per_channel=len(data) // 2, + ) + ) diff --git a/tests/test_vad.py b/tests/test_vad.py index 15d066571..940d67a06 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -21,6 +21,9 @@ async def 
test_chunks_vad() -> None: start_of_speech_i = 0 end_of_speech_i = 0 + + inference_frames = [] + async for ev in stream: if ev.type == vad.VADEventType.START_OF_SPEECH: with open( @@ -30,6 +33,9 @@ async def test_chunks_vad() -> None: start_of_speech_i += 1 + if ev.type == vad.VADEventType.INFERENCE_DONE: + inference_frames.extend(ev.frames) + if ev.type == vad.VADEventType.END_OF_SPEECH: with open( f"test_vad.end_of_speech_frames_{end_of_speech_i}.wav", "wb" @@ -41,6 +47,9 @@ async def test_chunks_vad() -> None: assert start_of_speech_i > 0, "no start of speech detected" assert start_of_speech_i == end_of_speech_i, "start and end of speech mismatch" + with open("test_vad.inference_frames.wav", "wb") as f: + f.write(utils.make_wav_file(inference_frames)) + async def test_file_vad(): frames, transcript = utils.make_test_audio()
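
Usage note (not part of the patch): a minimal sketch of the relocated audio helpers introduced above — `utils.combine_frames` (with `merge_frames` kept as a backwards-compatible alias) and `AudioByteStream.push` — assuming the import paths shown in the diff.

```python
# Sketch only: exercises the helpers moved into livekit.agents.utils.audio.
from livekit import rtc
from livekit.agents import utils

SAMPLE_RATE = 16000  # illustrative values
NUM_CHANNELS = 1

# AudioByteStream chunks arbitrary byte input into fixed-size frames
# (default frame size is sample_rate // 10, i.e. 100 ms of audio).
bstream = utils.audio.AudioByteStream(SAMPLE_RATE, NUM_CHANNELS)

frames: list[rtc.AudioFrame] = []
raw = b"\x00\x00" * (SAMPLE_RATE // 4)  # 250 ms of silence as 16-bit PCM
frames.extend(bstream.push(raw))   # new name; `write` remains as an alias
frames.extend(bstream.flush())     # emit whatever is left in the buffer

# combine_frames concatenates frames sharing a sample rate and channel count,
# preallocating the output buffer instead of repeatedly concatenating bytes.
combined = utils.combine_frames(frames)
print(combined.samples_per_channel, combined.sample_rate)
```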
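
Usage note (not part of the patch): a sketch of streaming audio at a non-16 kHz input rate through the updated Silero plugin, which is the behavior this PR enables via internal resampling. The `push_frame`/`end_input` calls reflect the usual agents stream API and are assumptions here, not something this diff shows.

```python
# Sketch only: input frames may use any sample rate; the stream resamples
# internally to the model rate (8 kHz / 16 kHz) with rtc.AudioResampler.
import asyncio

import numpy as np
from livekit import rtc
from livekit.agents import vad
from livekit.plugins import silero


async def main() -> None:
    detector = silero.VAD.load()  # blocking load; prefer a prewarm hook in real agents
    stream = detector.stream()

    input_rate = 44100  # no longer restricted to 8 kHz / 16 kHz inputs
    chunk = np.zeros(input_rate // 100, dtype=np.int16)  # 10 ms of silence, int16 PCM

    for _ in range(300):  # ~3 s of audio
        stream.push_frame(  # assumed stream API
            rtc.AudioFrame(
                data=chunk.tobytes(),
                sample_rate=input_rate,
                num_channels=1,
                samples_per_channel=len(chunk),
            )
        )
    stream.end_input()  # assumed stream API; closes the input side

    async for ev in stream:
        if ev.type == vad.VADEventType.INFERENCE_DONE:
            # ev.frames carries the input-rate audio covered by this window
            print(f"{ev.timestamp:.2f}s p={ev.probability:.2f} speaking={ev.speaking}")


if __name__ == "__main__":
    asyncio.run(main())
```

Per the diff, resampling uses `rtc.AudioResamplerQuality.QUICK` since VAD does not need high-fidelity audio, and `INFERENCE_DONE` events now expose both a `timestamp` and the exact input-rate frames consumed by each inference window.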