Skip to content

Commit

Permalink
silero: support any sample rate (livekit#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Sep 28, 2024
1 parent 688f5e8 commit f0ace90
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 150 deletions.
6 changes: 6 additions & 0 deletions .changeset/rare-cows-smile.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-silero": minor
---

silero: support any sample rate
5 changes: 5 additions & 0 deletions .changeset/warm-needles-change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-silero": patch
---

silero: add prefix_padding_duration #801
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from livekit import rtc

from ..utils import aio, misc
from ..utils import aio, audio


@dataclass
Expand Down Expand Up @@ -71,7 +71,7 @@ async def collect(self) -> rtc.AudioFrame:
frames = []
async for ev in self:
frames.append(ev.frame)
return misc.merge_frames(frames)
return audio.merge_frames(frames)

@abstractmethod
async def _main_task(self) -> None: ...
Expand Down
4 changes: 3 additions & 1 deletion livekit-agents/livekit/agents/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from . import aio, audio, codecs, http_context, images
from .audio import AudioBuffer, combine_frames, merge_frames
from .event_emitter import EventEmitter
from .exp_filter import ExpFilter
from .log import log_exceptions
from .misc import AudioBuffer, merge_frames, shortuuid, time_ms
from .misc import shortuuid, time_ms
from .moving_average import MovingAverage

__all__ = [
"AudioBuffer",
"merge_frames",
"combine_frames",
"time_ms",
"shortuuid",
"http_context",
Expand Down
147 changes: 146 additions & 1 deletion livekit-agents/livekit/agents/utils/audio.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,129 @@
from __future__ import annotations

import ctypes
from typing import List, Union

from livekit import rtc

from ..log import logger

AudioBuffer = Union[List[rtc.AudioFrame], rtc.AudioFrame]


def combine_frames(buffer: AudioBuffer) -> rtc.AudioFrame:
    """Concatenate an `AudioBuffer` into a single `rtc.AudioFrame`.

    A bare `rtc.AudioFrame` is returned unchanged; a list of frames is
    concatenated in order. Every frame in the list must share the same
    sample rate and channel count as the first one.

    Args:
        buffer: A single `rtc.AudioFrame` or a list of `rtc.AudioFrame`
            objects to combine.

    Returns:
        rtc.AudioFrame: One frame holding the concatenated audio data.

    Raises:
        ValueError: If the buffer list is empty.
        ValueError: If any frame disagrees with the first frame's sample
            rate or channel count.

    Example:
        >>> frame1 = rtc.AudioFrame(
        ...     data=b"\x01\x02", sample_rate=48000, num_channels=2, samples_per_channel=1
        ... )
        >>> frame2 = rtc.AudioFrame(
        ...     data=b"\x03\x04", sample_rate=48000, num_channels=2, samples_per_channel=1
        ... )
        >>> combine_frames([frame1, frame2]).data
        b'\x01\x02\x03\x04'
    """
    if not isinstance(buffer, list):
        return buffer

    if len(buffer) == 0:
        raise ValueError("buffer is empty")

    reference = buffer[0]
    sample_rate = reference.sample_rate
    num_channels = reference.num_channels

    sample_count = 0
    for frame in buffer:
        if frame.sample_rate != sample_rate:
            raise ValueError(
                f"Sample rate mismatch: expected {sample_rate}, got {frame.sample_rate}"
            )

        if frame.num_channels != num_channels:
            raise ValueError(
                f"Channel count mismatch: expected {num_channels}, got {frame.num_channels}"
            )

        sample_count += frame.samples_per_channel

    # Join the raw byte views in one pass: a single allocation, no
    # repeated buffer resizing.
    combined = b"".join(frame.data.cast("b") for frame in buffer)

    return rtc.AudioFrame(
        data=combined,
        sample_rate=sample_rate,
        num_channels=num_channels,
        samples_per_channel=sample_count,
    )


merge_frames = combine_frames


class AudioByteStream:
"""
Buffer and chunk audio byte data into fixed-size frames.
This class is designed to handle incoming audio data in bytes,
buffering it and producing audio frames of a consistent size.
It is mainly used to easily chunk big or too small audio frames
into a fixed size, helping to avoid processing very small frames
(which can be inefficient) and very large frames (which can cause
latency or processing delays). By normalizing frame sizes, it
facilitates consistent and efficient audio data processing.
"""

def __init__(
self,
sample_rate: int,
num_channels: int,
samples_per_channel: int | None = None,
) -> None:
"""
Initialize an AudioByteStream instance.
Parameters:
sample_rate (int): The audio sample rate in Hz.
num_channels (int): The number of audio channels.
samples_per_channel (int, optional): The number of samples per channel in each frame.
If None, defaults to `sample_rate // 10` (i.e., 100ms of audio data).
The constructor sets up the internal buffer and calculates the size of each frame in bytes.
The frame size is determined by the number of channels, samples per channel, and the size
of each sample (assumed to be 16 bits or 2 bytes).
"""
self._sample_rate = sample_rate
self._num_channels = num_channels

Expand All @@ -25,7 +135,25 @@ def __init__(
)
self._buf = bytearray()

def write(self, data: bytes) -> list[rtc.AudioFrame]:
def push(self, data: bytes) -> list[rtc.AudioFrame]:
"""
Add audio data to the buffer and retrieve fixed-size frames.
Parameters:
data (bytes): The incoming audio data to buffer.
Returns:
list[rtc.AudioFrame]: A list of `AudioFrame` objects of fixed size.
The method appends the incoming data to the internal buffer.
While the buffer contains enough data to form complete frames,
it extracts the data for each frame, creates an `AudioFrame` object,
and appends it to the list of frames to return.
This allows you to feed in variable-sized chunks of audio data
(e.g., from a stream or file) and receive back a list of
fixed-size audio frames ready for processing or transmission.
"""
self._buf.extend(data)

frames = []
Expand All @@ -44,7 +172,24 @@ def write(self, data: bytes) -> list[rtc.AudioFrame]:

return frames

write = push # Alias for the push method.

def flush(self) -> list[rtc.AudioFrame]:
"""
Flush the buffer and retrieve any remaining audio data as a frame.
Returns:
list[rtc.AudioFrame]: A list containing any remaining `AudioFrame` objects.
This method processes any remaining data in the buffer that does not
fill a complete frame. If the remaining data forms a partial frame
(i.e., its size is not a multiple of the expected sample size), a warning is
logged and an empty list is returned. Otherwise, it returns the final
`AudioFrame` containing the remaining data.
Use this method when you have no more data to push and want to ensure
that all buffered audio data has been processed.
"""
if len(self._buf) % (2 * self._num_channels) != 0:
logger.warning("AudioByteStream: incomplete frame during flush, dropping")
return []
Expand Down
40 changes: 0 additions & 40 deletions livekit-agents/livekit/agents/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,6 @@

import time
import uuid
from typing import List, Union

from livekit import rtc

AudioBuffer = Union[List[rtc.AudioFrame], rtc.AudioFrame]


def merge_frames(buffer: AudioBuffer) -> rtc.AudioFrame:
    """
    Merges one or more AudioFrames into a single one.

    Args:
        buffer: either a rtc.AudioFrame or a list of rtc.AudioFrame

    Returns:
        rtc.AudioFrame: a single frame containing the concatenated audio data.

    Raises:
        ValueError: if the list is empty, or if frames disagree on sample
            rate or channel count.
    """
    if not isinstance(buffer, list):
        return buffer

    if not buffer:
        raise ValueError("buffer is empty")

    sample_rate = buffer[0].sample_rate
    num_channels = buffer[0].num_channels
    samples_per_channel = 0
    chunks = []
    for frame in buffer:
        if frame.sample_rate != sample_rate:
            raise ValueError("sample rate mismatch")

        if frame.num_channels != num_channels:
            raise ValueError("channel count mismatch")

        # Collect chunks and join once at the end: repeated `bytes +=`
        # concatenation is quadratic in the total payload size.
        chunks.append(frame.data)
        samples_per_channel += frame.samples_per_channel

    return rtc.AudioFrame(
        data=b"".join(chunks),
        sample_rate=sample_rate,
        num_channels=num_channels,
        samples_per_channel=samples_per_channel,
    )


def time_ms() -> int:
Expand Down
38 changes: 28 additions & 10 deletions livekit-agents/livekit/agents/vad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand All @@ -18,26 +20,42 @@ class VADEventType(str, Enum):

@dataclass
class VADEvent:
    """
    Represents an event detected by the Voice Activity Detector (VAD).
    """

    type: VADEventType
    """Type of the VAD event (e.g., start of speech, end of speech, inference done)."""

    samples_index: int
    """Index of the audio sample where the event occurred, relative to the inference sample rate."""

    timestamp: float
    """Timestamp (in seconds) when the event was fired."""

    speech_duration: float
    """Duration of the detected speech segment in seconds."""

    silence_duration: float
    """Duration of the silence segment preceding or following the speech, in seconds."""

    frames: List[rtc.AudioFrame] = field(default_factory=list)
    """
    List of audio frames associated with the speech.

    - For `start_of_speech` events, this contains the audio chunks that triggered the detection.
    - For `inference_done` events, this contains the audio chunks that were processed.
    - For `end_of_speech` events, this contains the complete user speech.
    """

    probability: float = 0.0
    """Probability that speech is present (only for `INFERENCE_DONE` events)."""

    inference_duration: float = 0.0
    """Time taken to perform the inference, in seconds (only for `INFERENCE_DONE` events)."""

    speaking: bool = False
    """Indicates whether speech was detected in the frames."""


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
GroqChatModels,
OctoChatModels,
PerplexityChatModels,
TelnyxChatModels,
TogetherChatModels,
TelnyxChatModels
)
from .utils import AsyncAzureADTokenProvider, build_oai_message

Expand Down
Loading

0 comments on commit f0ace90

Please sign in to comment.