From 70bc0612753dd7d71090a7db8e95cc73afa1a210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Wed, 3 Jul 2024 03:49:49 +0200 Subject: [PATCH] ai functions improvements follow up (#393) * Update assistant.py * cleanup voice assistant & make it faster (needs improvements on reliability) --- examples/voice-assistant/minimal_assistant.py | 16 +- livekit-agents/livekit/agents/llm/_oai_api.py | 24 +- .../livekit/agents/llm/chat_context.py | 15 +- .../livekit/agents/llm/function_context.py | 29 +- livekit-agents/livekit/agents/llm/llm.py | 49 +- .../livekit/agents/utils/event_emitter.py | 12 +- livekit-agents/livekit/agents/vad.py | 23 +- .../agents/voice_assistant/__init__.py | 4 +- .../agents/voice_assistant/agent_output.py | 219 +++++ .../agents/voice_assistant/assistant.py | 874 ++---------------- .../agents/voice_assistant/call_context.py | 25 + .../voice_assistant/cancellable_source.py | 134 +++ .../agents/voice_assistant/human_input.py | 150 +++ .../livekit/agents/voice_assistant/impl.py | 421 +++++++++ .../livekit/agents/voice_assistant/log.py | 3 + .../livekit/plugins/openai/llm.py | 43 +- .../livekit/plugins/silero/vad.py | 49 +- 17 files changed, 1186 insertions(+), 904 deletions(-) create mode 100644 livekit-agents/livekit/agents/voice_assistant/agent_output.py create mode 100644 livekit-agents/livekit/agents/voice_assistant/call_context.py create mode 100644 livekit-agents/livekit/agents/voice_assistant/cancellable_source.py create mode 100644 livekit-agents/livekit/agents/voice_assistant/human_input.py create mode 100644 livekit-agents/livekit/agents/voice_assistant/impl.py create mode 100644 livekit-agents/livekit/agents/voice_assistant/log.py diff --git a/examples/voice-assistant/minimal_assistant.py b/examples/voice-assistant/minimal_assistant.py index 3a532e0b5..03fb7ee04 100644 --- a/examples/voice-assistant/minimal_assistant.py +++ b/examples/voice-assistant/minimal_assistant.py @@ -12,16 +12,12 @@ async def entrypoint(ctx: JobContext): - initial_ctx = ChatContext( - messages=[ - ChatMessage( - role=ChatRole.SYSTEM, - text=( - "You are a voice assistant created by LiveKit. Your interface with users will be voice. " - "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." - ), - ) - ] + initial_ctx = ChatContext().append( + role="system", + text=( + "You are a voice assistant created by LiveKit. Your interface with users will be voice. " + "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." 
+ ), ) assistant = VoiceAssistant( diff --git a/livekit-agents/livekit/agents/llm/_oai_api.py b/livekit-agents/livekit/agents/llm/_oai_api.py index cfe5abef8..391bd4414 100644 --- a/livekit-agents/livekit/agents/llm/_oai_api.py +++ b/livekit-agents/livekit/agents/llm/_oai_api.py @@ -29,12 +29,12 @@ ] -def create_ai_function_task( +def create_ai_function_info( fnc_ctx: function_context.FunctionContext, tool_call_id: str, fnc_name: str, raw_arguments: str, # JSON string -) -> tuple[asyncio.Task[Any], function_context.CalledFunction]: +) -> function_context.FunctionCallInfo: if fnc_name not in fnc_ctx.ai_functions: raise ValueError(f"AI function {fnc_name} not found") @@ -80,21 +80,11 @@ def create_ai_function_task( sanitized_arguments[arg_info.name] = sanitized_value - func = functools.partial(fnc_info.callable, **sanitized_arguments) - if asyncio.iscoroutinefunction(fnc_info.callable): - task = asyncio.create_task(func()) - else: - task = asyncio.create_task(asyncio.to_thread(func)) - - return ( - task, - function_context.CalledFunction( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, - task=task, - ), + return function_context.FunctionCallInfo( + tool_call_id=tool_call_id, + raw_arguments=raw_arguments, + function_info=fnc_info, + arguments=sanitized_arguments, ) diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index de6705b67..b42fcf2eb 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -40,7 +40,7 @@ class ChatMessage: role: ChatRole name: str | None = None content: str | list[str | ChatImage] | None = None - tool_calls: list[function_context.CalledFunction] | None = None + tool_calls: list[function_context.FunctionCallInfo] | None = None tool_call_id: str | None = None @staticmethod @@ -50,20 +50,21 @@ def create_tool_from_called_function( if not called_function.task.done(): raise ValueError("cannot create a tool result from a running ai function") - content = called_function.task.result() - if called_function.task.exception() is not None: - content = f"Error: {called_function.task.exception}" + try: + content = called_function.task.result() + except BaseException as e: + content = f"Error: {e}" return ChatMessage( role="tool", - name=called_function.function_info.name, + name=called_function.call_info.function_info.name, content=content, - tool_call_id=called_function.tool_call_id, + tool_call_id=called_function.call_info.tool_call_id, ) @staticmethod def create_tool_calls( - called_functions: list[function_context.CalledFunction], + called_functions: list[function_context.FunctionCallInfo], ) -> "ChatMessage": return ChatMessage( role="assistant", diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index a189275b4..069a3d78c 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import functools import enum import inspect import typing @@ -57,12 +58,38 @@ class FunctionInfo: @dataclass -class CalledFunction: +class FunctionCallInfo: tool_call_id: str function_info: FunctionInfo raw_arguments: str arguments: dict[str, Any] + + def execute(self) -> CalledFunction: + function_info = self.function_info + func = functools.partial(function_info.callable, **self.arguments) + if 
asyncio.iscoroutinefunction(function_info.callable): + task = asyncio.create_task(func()) + else: + task = asyncio.create_task(asyncio.to_thread(func)) + + called_fnc = CalledFunction(call_info=self, task=task) + + def _on_done(fut): + try: + called_fnc.result = fut.result() + except BaseException as e: + called_fnc.exception = e + + task.add_done_callback(_on_done) + return called_fnc + + +@dataclass +class CalledFunction: + call_info: FunctionCallInfo task: asyncio.Task[Any] + result: Any | None = None + exception: BaseException | None = None def ai_callable( diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py index 1a859cce1..b0b69be85 100644 --- a/livekit-agents/livekit/agents/llm/llm.py +++ b/livekit-agents/livekit/agents/llm/llm.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import abc from dataclasses import dataclass, field from typing import AsyncIterator @@ -12,7 +13,7 @@ class ChoiceDelta: role: ChatRole content: str | None = None - tool_calls: list[function_context.CalledFunction] | None = None + tool_calls: list[function_context.FunctionCallInfo] | None = None @dataclass @@ -28,7 +29,7 @@ class ChatChunk: class LLM(abc.ABC): @abc.abstractmethod - async def chat( + def chat( self, *, chat_ctx: ChatContext, @@ -39,24 +40,48 @@ async def chat( class LLMStream(abc.ABC): - def __init__(self) -> None: - self._called_functions: list[function_context.CalledFunction] = [] + def __init__( + self, *, chat_ctx: ChatContext, fnc_ctx: function_context.FunctionContext | None + ) -> None: + self._function_calls_info: list[function_context.FunctionCallInfo] = [] + self._tasks = set[asyncio.Task]() + self._chat_ctx = chat_ctx + self._fnc_ctx = fnc_ctx @property - def called_functions(self) -> list[function_context.CalledFunction]: + def function_calls(self) -> list[function_context.FunctionCallInfo]: """List of called functions from this stream.""" - return self._called_functions + return self._function_calls_info - @abc.abstractmethod - async def gather_function_results( + @property + def chat_ctx(self) -> ChatContext: + """The initial chat context of this stream.""" + return self._chat_ctx + + @property + def fnc_ctx(self) -> function_context.FunctionContext | None: + """The function context of this stream.""" + return self._fnc_ctx + + def execute_functions( self, - ) -> list[function_context.CalledFunction]: ... + ) -> list[function_context.CalledFunction]: + """Run all functions in this stream.""" + called_functions = [] + for fnc_info in self._function_calls_info: + called_fnc = fnc_info.execute() + called_functions.append(called_fnc) + + return called_functions + + async def aclose(self) -> None: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) def __aiter__(self) -> AsyncIterator[ChatChunk]: return self @abc.abstractmethod async def __anext__(self) -> ChatChunk: ... - - @abc.abstractmethod - async def aclose(self) -> None: ... 
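Usage note (editor's sketch, not part of the patch): the function_context.py and llm.py changes above replace the old create_ai_function_task/CalledFunction pair with a two-step flow — the LLM stream now only collects FunctionCallInfo entries, and the caller decides when to run them via FunctionCallInfo.execute(), which returns a CalledFunction holding the asyncio.Task plus its result/exception. A minimal consumer could look like the following; the `from livekit.agents import llm` import path and the direct use of `chat_ctx.messages` are assumptions, everything else follows the signatures in this diff.

    import asyncio

    from livekit.agents import llm


    async def answer_with_tools(
        model: llm.LLM,
        chat_ctx: llm.ChatContext,
        fnc_ctx: llm.FunctionContext,
    ) -> None:
        # chat() is no longer a coroutine: it returns an LLMStream directly
        stream = model.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx)

        async for chunk in stream:
            delta = chunk.choices[0].delta
            if delta.content:
                ...  # forward streamed text to TTS / UI

        if stream.function_calls:
            # execute() wraps each call in an asyncio.Task (asyncio.to_thread for sync callables)
            called = stream.execute_functions()
            await asyncio.gather(*(cf.task for cf in called), return_exceptions=True)

            # feed the tool calls and their results back into the context for a follow-up completion
            chat_ctx.messages.append(llm.ChatMessage.create_tool_calls(stream.function_calls))
            for cf in called:
                chat_ctx.messages.append(llm.ChatMessage.create_tool_from_called_function(cf))

        await stream.aclose()

This mirrors how impl.py consumes the stream further down in the patch: text deltas are forwarded as they arrive, and function execution is deferred until the caller (or the assistant) chooses to run it.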
diff --git a/livekit-agents/livekit/agents/utils/event_emitter.py b/livekit-agents/livekit/agents/utils/event_emitter.py index 289084298..5dcc65efb 100644 --- a/livekit-agents/livekit/agents/utils/event_emitter.py +++ b/livekit-agents/livekit/agents/utils/event_emitter.py @@ -5,7 +5,7 @@ class EventEmitter(Generic[T]): def __init__(self) -> None: - self._events: Dict[T, Set[Callable[[Any], None]]] = dict() + self._events: Dict[T, Set[Callable[..., Any]]] = dict() def emit(self, event: T, *args: Any, **kwargs: Any) -> None: if event in self._events: @@ -13,7 +13,7 @@ def emit(self, event: T, *args: Any, **kwargs: Any) -> None: for callback in callables: callback(*args, **kwargs) - def once(self, event: T, callback: Optional[Callable[[Any], None]] = None): + def once(self, event: T, callback: Optional[Callable[..., Any]] = None): if callback is not None: def once_callback(*args: Any, **kwargs: Any): @@ -23,13 +23,13 @@ def once_callback(*args: Any, **kwargs: Any): return self.on(event, once_callback) else: - def decorator(callback: Callable[[Any], None]): + def decorator(callback: Callable[..., Any]): self.once(event, callback) return callback return decorator - def on(self, event: T, callback: Optional[Callable[[Any], None]] = None): + def on(self, event: T, callback: Optional[Callable[..., Any]] = None): if callback is not None: if event not in self._events: self._events[event] = set() @@ -37,12 +37,12 @@ def on(self, event: T, callback: Optional[Callable[[Any], None]] = None): return callback else: - def decorator(callback: Callable[[Any], None]): + def decorator(callback: Callable[..., Any]): self.on(event, callback) return callback return decorator - def off(self, event: T, callback: Callable[[Any], None]) -> None: + def off(self, event: T, callback: Callable[..., Any]) -> None: if event in self._events: self._events[event].remove(callback) diff --git a/livekit-agents/livekit/agents/vad.py b/livekit-agents/livekit/agents/vad.py index db7488258..461af675f 100644 --- a/livekit-agents/livekit/agents/vad.py +++ b/livekit-agents/livekit/agents/vad.py @@ -1,15 +1,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from enum import Enum +from enum import Enum, unique from typing import AsyncIterator, List from livekit import rtc +@unique class VADEventType(Enum): - START_OF_SPEECH = 1 - INFERENCE_DONE = 2 - END_OF_SPEECH = 3 + START_OF_SPEECH = "start_of_speech" + INFERENCE_DONE = "inference_done" + END_OF_SPEECH = "end_of_speech" @dataclass @@ -17,9 +18,9 @@ class VADEvent: type: VADEventType """type of the event""" samples_index: int - """index of the samples of the event (when the event was fired)""" - duration: float = 0.0 - """duration of the speech in seconds (only for END_SPEAKING event)""" + """index of the samples when the event was fired""" + duration: float + """duration of the speech in seconds""" frames: List[rtc.AudioFrame] = field(default_factory=list) """list of audio frames of the speech""" probability: float = 0.0 @@ -31,6 +32,14 @@ class VADEvent: class VAD(ABC): + def __init__(self, *, update_interval: float) -> None: + self._update_interval = update_interval + + @property + def update_interval(self) -> float: + """interval in seconds to update the VAD model""" + return self._update_interval + @abstractmethod def stream( self, diff --git a/livekit-agents/livekit/agents/voice_assistant/__init__.py b/livekit-agents/livekit/agents/voice_assistant/__init__.py index 8d5bfd594..b63d19f75 100644 --- 
a/livekit-agents/livekit/agents/voice_assistant/__init__.py +++ b/livekit-agents/livekit/agents/voice_assistant/__init__.py @@ -1,3 +1,3 @@ -from .assistant import AssistantCallContext, VoiceAssistant +from .assistant import VoiceAssistant -__all__ = ["VoiceAssistant", "AssistantCallContext"] +__all__ = ["VoiceAssistant"] diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py new file mode 100644 index 000000000..f7f2c7ec3 --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import asyncio +import contextlib +from typing import AsyncIterable, Union + +from livekit import rtc + +from .. import aio, transcription, utils +from .. import llm as llm +from .. import tts as text_to_speech +from .cancellable_source import CancellableAudioSource, PlayoutHandle +from .log import logger + +SpeechSource = Union[AsyncIterable[str], str] + + +class SynthesisHandle: + def __init__( + self, + *, + speech_source: SpeechSource, + audio_source: CancellableAudioSource, + tts: text_to_speech.TTS, + transcription_fwd: transcription.TTSSegmentsForwarder | None = None, + ) -> None: + self._speech_source, self._audio_source, self._tts, self._tr_fwd = ( + speech_source, + audio_source, + tts, + transcription_fwd, + ) + self._buf_ch = aio.Chan[rtc.AudioFrame]() + self._play_handle: PlayoutHandle | None = None + self._interrupt_fut = asyncio.Future() + self._collected_text = "" # collected text from the async stream + + @property + def validated(self) -> bool: + return self._play_handle is not None + + @property + def interrupted(self) -> bool: + return self._interrupt_fut.done() + + @property + def collected_text(self) -> str: + return self._collected_text + + @property + def play_handle(self) -> PlayoutHandle | None: + return self._play_handle + + def play(self) -> PlayoutHandle: + """Validate the speech for playout""" + if self.interrupted: + raise RuntimeError("synthesis was interrupted") + + self._play_handle = self._audio_source.play( + self._buf_ch, + ) + return self._play_handle + + def interrupt(self) -> None: + """Interrupt the speech""" + if self._play_handle is not None: + self._play_handle.interrupt() + + self._interrupt_fut.set_result(None) + + +class AgentOutput: + def __init__( + self, + *, + room: rtc.Room, + source: CancellableAudioSource, + llm: llm.LLM, + tts: text_to_speech.TTS, + ) -> None: + self._room, self._source, self._llm, self._tts = room, source, llm, tts + self._tasks = set() + + async def aclose(self) -> None: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + def synthesize(self, *, transcript: SpeechSource) -> SynthesisHandle: + transcription_fwd = transcription.TTSSegmentsForwarder( + room=self._room, participant=self._room.local_participant + ) + + handle = SynthesisHandle( + speech_source=transcript, + audio_source=self._source, + tts=self._tts, + transcription_fwd=transcription_fwd, + ) + + task = asyncio.create_task( + self._synthesize_task( + handle, + ) + ) + self._tasks.add(task) + task.add_done_callback(self._tasks.remove) + return handle + + @utils.log_exceptions(logger=logger) + async def _synthesize_task( + self, + handle: SynthesisHandle, + ) -> None: + """Synthesize speech from the source""" + if isinstance(handle._speech_source, str): + co = _str_synthesis_co(handle._speech_source, handle) + else: + co = 
_stream_synthesis_co(handle._speech_source, handle) + + synth = asyncio.create_task(co) + try: + _ = await asyncio.wait( + [synth, handle._interrupt_fut], return_when=asyncio.FIRST_COMPLETED + ) + finally: + synth.cancel() + with contextlib.suppress(asyncio.CancelledError): + await synth + + +async def _str_synthesis_co( + text: str, + handle: SynthesisHandle, +) -> None: + """synthesize speech from a string""" + if handle._tr_fwd is not None: + handle._tr_fwd.push_text(text) + handle._tr_fwd.mark_text_segment_end() + + # start_time = time.time() + # first_frame = True + # audio_duration = 0.0 + handle._collected_text = text + + try: + async for audio in handle._tts.synthesize(text): + # if first_frame: + # first_frame = False + # dt = time.time() - start_time + # self._log_debug(f"tts first frame in {dt:.2f}s") + + frame = audio.data + # audio_duration += frame.samples_per_channel / frame.sample_rate + + handle._buf_ch.send_nowait(frame) + if handle._tr_fwd is not None: + handle._tr_fwd.push_audio(frame) + + finally: + if handle._tr_fwd is not None: + handle._tr_fwd.mark_audio_segment_end() + handle._buf_ch.close() + # self._log_debug(f"tts finished synthesising {audio_duration:.2f}s of audio") + + +async def _stream_synthesis_co( + streamed_text: AsyncIterable[str], + handle: SynthesisHandle, +) -> None: + """synthesize speech from streamed text""" + + async def _read_generated_audio_task(): + # start_time = time.time() + # first_frame = True + # audio_duration = 0.0 + async for event in tts_stream: + if event.type == text_to_speech.SynthesisEventType.AUDIO: + # if first_frame: + # first_frame = False + # dt = time.time() - start_time + # self._log_debug(f"tts first frame in {dt:.2f}s (streamed)") + + assert event.audio is not None + frame = event.audio.data + # audio_duration += frame.samples_per_channel / frame.sample_rate + if handle._tr_fwd is not None: + handle._tr_fwd.push_audio(frame) + handle._buf_ch.send_nowait(frame) + + # self._log_debug( + # f"tts finished synthesising {audio_duration:.2f}s audio (streamed)" + # ) + + # otherwise, stream the text to the TTS + tts_stream = handle._tts.stream() + read_atask = asyncio.create_task(_read_generated_audio_task()) + + try: + async for seg in streamed_text: + handle._collected_text = seg + if handle._tr_fwd is not None: + handle._tr_fwd.push_text(seg) + + tts_stream.push_text(seg) + + finally: + if handle._tr_fwd is not None: + handle._tr_fwd.mark_text_segment_end() + + tts_stream.mark_segment_end() + await tts_stream.aclose() + await read_atask + + if handle._tr_fwd is not None: + handle._tr_fwd.mark_audio_segment_end() + + handle._buf_ch.close() diff --git a/livekit-agents/livekit/agents/voice_assistant/assistant.py b/livekit-agents/livekit/agents/voice_assistant/assistant.py index 5c8d4eaab..9975c5b97 100644 --- a/livekit-agents/livekit/agents/voice_assistant/assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/assistant.py @@ -1,89 +1,19 @@ from __future__ import annotations import asyncio -import contextlib -import contextvars -import logging -import time -from dataclasses import dataclass -from typing import Any, AsyncIterable, Callable, Literal +from typing import Any, AsyncIterable, Awaitable, Callable, Literal, Optional, Union from livekit import rtc -from .. import aio, tokenize, transcription, utils -from .. import llm as allm -from .. import stt as astt -from .. import tts as atts -from .. import vad as avad -from . import plotter +from .. import llm, stt, tokenize, utils, vad +from .. 
import tts as text_to_speech +from . import impl -logger = logging.getLogger("livekit.agents.voice_assistant") - -@dataclass -class _SpeechData: - source: str | allm.LLMStream | AsyncIterable[str] - allow_interruptions: bool - add_to_ctx: bool # should this synthesis be added to the chat context - validation_future: asyncio.Future[None] # validate the speech for playout - validated: bool = False - interrupted: bool = False - user_question: str | None = None - collected_text: str = "" - - def validate_speech(self) -> None: - self.validated = True - with contextlib.suppress(asyncio.InvalidStateError): - self.validation_future.set_result(None) - - -@dataclass(frozen=True) -class _AssistantOptions: - plotting: bool - debug: bool - allow_interruptions: bool - int_speech_duration: float - int_min_words: int - base_volume: float - transcription: bool - preemptive_synthesis: bool - word_tokenizer: tokenize.WordTokenizer - sentence_tokenizer: tokenize.SentenceTokenizer - hyphenate_word: Callable[[str], list[str]] - transcription_speed: float - - -@dataclass(frozen=True) -class _StartArgs: - room: rtc.Room - participant: rtc.RemoteParticipant | str | None - - -_ContextVar = contextvars.ContextVar["AssistantContext"]("voice_assistant_contextvar") - - -class AssistantCallContext: - def __init__(self, assistant: "VoiceAssistant", llm_stream: allm.LLMStream) -> None: - self._assistant = assistant - self._metadata = dict[str, Any]() - self._llm_stream = llm_stream - - @staticmethod - def get_current() -> "AssistantCallContext": - return _ContextVar.get() - - @property - def assistant(self) -> "VoiceAssistant": - return self._assistant - - def store_metadata(self, key: str, value: Any) -> None: - self._metadata[key] = value - - def get_metadata(self, key: str, default: Any = None) -> Any: - return self._metadata.get(key, default) - - def llm_stream(self) -> allm.LLMStream: - return self._llm_stream +async def _default_will_create_llm_stream( + assistant: VoiceAssistant, chat_ctx: llm.ChatContext +) -> llm.LLMStream: + return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx) EventTypes = Literal[ @@ -98,97 +28,77 @@ def llm_stream(self) -> allm.LLMStream: "function_calls_finished", ] +WillCreateLLMStream = Callable[ + ["VoiceAssistant", llm.ChatContext], + Union[Optional[llm.LLMStream], Awaitable[Optional[llm.LLMStream]]], +] + class VoiceAssistant(utils.EventEmitter[EventTypes]): def __init__( self, *, - vad: avad.VAD, - stt: astt.STT, - llm: allm.LLM, - tts: atts.TTS, - chat_ctx: allm.ChatContext | None = None, - fnc_ctx: allm.FunctionContext | None = None, + vad: vad.VAD, + stt: stt.STT, + llm: llm.LLM, + tts: text_to_speech.TTS, + chat_ctx: llm.ChatContext = llm.ChatContext(), + fnc_ctx: llm.FunctionContext | None = None, allow_interruptions: bool = True, interrupt_speech_duration: float = 0.65, interrupt_min_words: int = 3, - base_volume: float = 1.0, - debug: bool = False, - plotting: bool = False, preemptive_synthesis: bool = True, - loop: asyncio.AbstractEventLoop | None = None, transcription: bool = True, + will_create_llm_stream: WillCreateLLMStream = _default_will_create_llm_stream, sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer(), word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(), hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word, transcription_speed: float = 3.83, + debug: bool = False, + plotting: bool = False, + loop: asyncio.AbstractEventLoop | None = None, ) -> None: super().__init__() - 
self._loop = loop or asyncio.get_event_loop() - self._opts = _AssistantOptions( + loop = loop or asyncio.get_event_loop() + + def will_create_llm_stream_impl( + _: impl.AssistantImpl, chat_ctx: llm.ChatContext + ): + return will_create_llm_stream(self, chat_ctx) + + opts = impl.ImplOptions( plotting=plotting, debug=debug, allow_interruptions=allow_interruptions, int_speech_duration=interrupt_speech_duration, int_min_words=interrupt_min_words, - base_volume=base_volume, preemptive_synthesis=preemptive_synthesis, transcription=transcription, sentence_tokenizer=sentence_tokenizer, word_tokenizer=word_tokenizer, hyphenate_word=hyphenate_word, transcription_speed=transcription_speed, + will_create_llm_stream=will_create_llm_stream_impl, ) - # wrap with adapter automatically with default options + # wrap with StreamAdapter automatically when streaming is not supported on a specific TTS # to override StreamAdapter options, create the adapter manually if not tts.streaming_supported: - tts = atts.StreamAdapter( + tts = text_to_speech.StreamAdapter( tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() ) - self._vad, self._tts, self._llm, self._stt = vad, tts, llm, stt - self._fnc_ctx = fnc_ctx - self._chat_ctx = chat_ctx or allm.ChatContext() - self._plotter = plotter.AssistantPlotter(self._loop) - - self._audio_source: rtc.AudioSource | None = None # published agent audiotrack - self._user_track: rtc.RemoteAudioTrack | None = None # user microphone track - self._user_identity: str | None = None # linked participant identity - - self._started = False - self._start_speech_lock = asyncio.Lock() - self._pending_validation = False - - # tasks - self._recognize_atask: asyncio.Task[None] | None = None - self._play_atask: asyncio.Task[None] | None = None - self._tasks = set[asyncio.Task[Any]]() - - # playout state - self._maybe_answer_task: asyncio.Task[None] | None = None - self._validated_speech: _SpeechData | None = None - self._answer_speech: _SpeechData | None = None - self._playout_start_time: float | None = None - - # synthesis state - self._speech_playing: _SpeechData | None = None # validated and playing speech - self._user_speaking, self._agent_speaking = False, False - - self._target_volume = self._opts.base_volume - self._vol_filter = utils.ExpFilter(0.9, max_val=self._opts.base_volume) - self._vol_filter.apply(1.0, self._opts.base_volume) - self._speech_prob = 0.0 - self._transcribed_text, self._interim_text = "", "" - self._ready_future = asyncio.Future[None]() - - @property - def chat_context(self) -> allm.ChatContext: - return self._chat_ctx - - @property - def started(self) -> bool: - return self._started + self._impl = impl.AssistantImpl( + vad=vad, + stt=stt, + llm=llm, + tts=tts, + emitter=self, + options=opts, + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + loop=loop, + ) def start( self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None @@ -200,25 +110,14 @@ def start( participant: the participant to listen to, can either be a participant or a participant identity If None, the first participant found in the room will be selected """ - if self.started: - raise RuntimeError("voice assistant already started") - - self._started = True - self._start_args = _StartArgs(room=room, participant=participant) - - room.on("track_published", self._on_track_published) - room.on("track_subscribed", self._on_track_subscribed) - room.on("track_unsubscribed", self._on_track_unsubscribed) - room.on("participant_connected", self._on_participant_connected) - - self._main_atask = 
asyncio.create_task(self._main_task()) + self._impl.start(room=room, participant=participant) async def say( self, - source: str | allm.LLMStream | AsyncIterable[str], + source: str | llm.LLMStream | AsyncIterable[str], *, allow_interruptions: bool = True, - add_to_chat_context: bool = True, + add_to_chat_ctx: bool = True, ) -> None: """ Make the assistant say something. @@ -227,22 +126,19 @@ async def say( Args: source: the source of the speech allow_interruptions: whether the speech can be interrupted - add_to_chat_context: whether to add the speech to the chat context + add_to_chat_ctx: whether to add the speech to the chat context """ - await self._wait_ready() - - data = _SpeechData( + await self._impl.say( source=source, allow_interruptions=allow_interruptions, - add_to_ctx=add_to_chat_context, - validation_future=asyncio.Future(), + add_to_chat_ctx=add_to_chat_ctx, ) - data.validate_speech() - - await self._start_speech(data, interrupt_current_if_possible=False) - assert self._play_atask is not None - await self._play_atask + async def aclose(self) -> None: + """ + Close the voice assistant + """ + await self._impl.aclose() def on(self, event: EventTypes, callback: Callable[[Any], None] | None = None): """Register a callback for an event @@ -262,646 +158,30 @@ def on(self, event: EventTypes, callback: Callable[[Any], None] | None = None): """ return super().on(event, callback) - async def aclose(self, wait: bool = True) -> None: - """ - Close the voice assistant - - Args: - wait: whether to wait for the current speech to finish before closing - """ - if not self.started: - return - - self._ready_future.cancel() - - self._start_args.room.off("track_published", self._on_track_published) - self._start_args.room.off("track_subscribed", self._on_track_subscribed) - self._start_args.room.off("track_unsubscribed", self._on_track_unsubscribed) - self._start_args.room.off( - "participant_connected", self._on_participant_connected - ) - - self._plotter.terminate() - - with contextlib.suppress(asyncio.CancelledError): - self._main_atask.cancel() - await self._main_atask - - if self._recognize_atask is not None: - self._recognize_atask.cancel() - - if not wait: - if self._play_atask is not None: - self._play_atask.cancel() - - with contextlib.suppress(asyncio.CancelledError): - if self._play_atask is not None: - await self._play_atask - - if self._recognize_atask is not None: - await self._recognize_atask - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - """ - Main task is publising the agent audio track and run the update loop - """ - if self._opts.plotting: - self._plotter.start() - - if self._start_args.participant is not None: - if isinstance(self._start_args.participant, rtc.RemoteParticipant): - self._link_participant(self._start_args.participant.identity) - else: - self._link_participant(self._start_args.participant) - else: - # no participant provided, try to find the first in the room - for participant in self._start_args.room.participants.values(): - self._link_participant(participant.identity) - break - - self._audio_source = rtc.AudioSource( - self._tts.sample_rate, self._tts.num_channels - ) - - track = rtc.LocalAudioTrack.create_audio_track( - "assistant_voice", self._audio_source - ) - options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) - self._pub = await self._start_args.room.local_participant.publish_track( - track, options - ) - - self._ready_future.set_result(None) - - # Loop running each 10ms to do the following: - # - 
Update the volume based on the user speech probability - # - Decide when to interrupt the agent speech - # - Decide when to validate the user speech (starting the agent answer) - speech_prob_avg = utils.MovingAverage(100) - speaking_avg_validation = utils.MovingAverage(150) - interruption_speaking_avg = utils.MovingAverage( - int(self._opts.int_speech_duration * 100) - ) - - interval_10ms = aio.interval(0.01) - - vad_pw = 2.4 # TODO(theomonnom): should this be exposed? - while True: - await interval_10ms.tick() - - speech_prob_avg.add_sample(self._speech_prob) - speaking_avg_validation.add_sample(int(self._user_speaking)) - interruption_speaking_avg.add_sample(int(self._user_speaking)) - - bvol = self._opts.base_volume - self._target_volume = max(0, 1 - speech_prob_avg.get_avg() * vad_pw) * bvol - - if self._validated_speech: - if not self._validated_speech.allow_interruptions: - # avoid volume to go to 0 even if speech probability is high - self._target_volume = max(self._target_volume, bvol * 0.5) - - if self._validated_speech.interrupted: - # the current speech is interrupted, target volume should be 0 - self._target_volume = 0 - - if self._user_speaking: - # if the user has been speaking int_speed_duration, interrupt the agent speech - # (this currently allows 10% of noise in the VAD) - if interruption_speaking_avg.get_avg() >= 0.1: - self._interrupt_if_needed() - elif self._pending_validation: - if speaking_avg_validation.get_avg() <= 0.05: - self._validate_answer_if_needed() - - self._plotter.plot_value("raw_vol", self._target_volume) - self._plotter.plot_value("vad_probability", self._speech_prob) - - def _link_participant(self, identity: str) -> None: - p = self._start_args.room.participants_by_identity.get(identity) - assert p is not None, "_link_participant should be called with a valid identity" - - # set self._user_identity before calling _on_track_published or _on_track_subscribed - self._user_identity = identity - self._log_debug(f"linking participant {identity}") - - for pub in p.tracks.values(): - if pub.subscribed: - self._on_track_subscribed(pub.track, pub, p) # type: ignore - else: - self._on_track_published(pub, p) - - def _on_participant_connected(self, participant: rtc.RemoteParticipant): - if not self._user_identity: - self._link_participant(participant.identity) - - def _on_track_published( - self, pub: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant - ): - if ( - participant.identity != self._user_identity - or pub.source != rtc.TrackSource.SOURCE_MICROPHONE - ): - return - - if not pub.subscribed: - pub.set_subscribed(True) - - def _on_track_subscribed( - self, - track: rtc.RemoteTrack, - pub: rtc.RemoteTrackPublication, - participant: rtc.RemoteParticipant, - ): - if ( - participant.identity != self._user_identity - or pub.source != rtc.TrackSource.SOURCE_MICROPHONE - ): - return - - self._log_debug("starting listening to user microphone") - self._user_track = track # type: ignore - self._recognize_atask = asyncio.create_task( - self._recognize_task(rtc.AudioStream(track)) - ) - - def _on_track_unsubscribed( - self, - track: rtc.RemoteTrack, - pub: rtc.RemoteTrackPublication, - participant: rtc.RemoteParticipant, - ): - if ( - participant.identity != self._user_identity - or pub.source != rtc.TrackSource.SOURCE_MICROPHONE - or self._user_track is None - ): - return - - # user microphone unsubscribed, (participant disconnected/track unpublished) - self._log_debug("user microphone not available anymore") - assert ( - self._recognize_atask is not None 
- ), "recognize task should be running when user_track was set" - self._recognize_atask.cancel() - self._user_track = None - - @utils.log_exceptions(logger=logger) - async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: - """ - Receive the frames from the user audio stream and do the following: - - do Voice Activity Detection (VAD) - - do Speech-to-Text (STT) - """ - assert ( - self._user_identity is not None - ), "user identity should be set before recognizing" - - vad_stream = self._vad.stream() - stt_stream = self._stt.stream() - - stt_forwarder: transcription.STTSegmentsForwarder = ( - transcription.NoopSTTSegmentsForwarder() - ) - if self._opts.transcription: - stt_forwarder = transcription.STTSegmentsForwarder( - room=self._start_args.room, - participant=self._user_identity, - track=self._user_track, - ) - - async def _audio_stream_co() -> None: - async for ev in audio_stream: - stt_stream.push_frame(ev.frame) - vad_stream.push_frame(ev.frame) - - async def _vad_stream_co() -> None: - async for ev in vad_stream: - if ev.type == avad.VADEventType.START_OF_SPEECH: - self._log_debug("user started speaking") - self._plotter.plot_event("user_started_speaking") - self._user_speaking = True - self.emit("user_started_speaking") - elif ev.type == avad.VADEventType.INFERENCE_DONE: - self._speech_prob = ev.probability - elif ev.type == avad.VADEventType.END_OF_SPEECH: - self._log_debug(f"user stopped speaking {ev.duration:.2f}s") - self._plotter.plot_event("user_started_speaking") - self._pending_validation = True - self._user_speaking = False - self.emit("user_stopped_speaking") - - async def _stt_stream_co() -> None: - async for ev in stt_stream: - stt_forwarder.update(ev) - if ev.type == astt.SpeechEventType.FINAL_TRANSCRIPT: - self._on_final_transcript(ev.alternatives[0].text) - elif ev.type == astt.SpeechEventType.INTERIM_TRANSCRIPT: - # interim transcript is used in combination with VAD - # to interrupt the current speech. 
- # (can be disabled by setting int_min_words to 0) - self._interim_text = ev.alternatives[0].text - elif ev.type == astt.SpeechEventType.END_OF_SPEECH: - self._pending_validation = True - - try: - await asyncio.gather( - _audio_stream_co(), - _vad_stream_co(), - _stt_stream_co(), - ) - finally: - await asyncio.gather( - stt_forwarder.aclose(wait=False), - stt_stream.aclose(wait=False), - vad_stream.aclose(), - ) - - def _on_final_transcript(self, text: str) -> None: - self._transcribed_text += text - self._log_debug(f"received final transcript: {self._transcribed_text}") - - # to create an llm stream we need an async context - # setting it to "" and will be updated inside the _answer_task below - # (this function can't be async because we don't want to block _update_co) - self._answer_speech = _SpeechData( - source="", - allow_interruptions=self._opts.allow_interruptions, - add_to_ctx=True, - validation_future=asyncio.Future(), - user_question=self._transcribed_text, - ) - - # this speech may not be validated, so we create a copy - # of our context to add the new user message - copied_ctx = self._chat_ctx.copy() - copied_ctx.messages.append( - allm.ChatMessage( - text=self._transcribed_text, - role=allm.ChatRole.USER, - ) - ) - - if self._maybe_answer_task is not None: - self._maybe_answer_task.cancel() - - async def _answer_task(ctx: allm.ChatContext, data: _SpeechData) -> None: - try: - data.source = await self._llm.chat(ctx, fnc_ctx=self._fnc_ctx) - await self._start_speech(data, interrupt_current_if_possible=False) - except Exception: - logger.exception("error while answering") - - t = asyncio.create_task(_answer_task(copied_ctx, self._answer_speech)) - self._maybe_answer_task = t - self._tasks.add(t) - t.add_done_callback(self._tasks.discard) - - def _interrupt_if_needed(self) -> None: - """ - Check whether the current assistant speech should be interrupted - """ - if ( - not self._validated_speech - or not self._opts.allow_interruptions - or self._validated_speech.interrupted - ): - return - - if self._opts.int_min_words != 0: - txt = self._transcribed_text.strip().split() - if len(txt) <= self._opts.int_min_words: - txt = self._interim_text.strip().split() - if len(txt) <= self._opts.int_min_words: - return - - if ( - self._playout_start_time is not None - and (time.time() - self._playout_start_time) < 1 - ): # don't interrupt new speech (if they're not older than 1s) - return - - self._validated_speech.interrupted = True - self._validate_answer_if_needed() - self._log_debug("user interrupted assistant speech") - - def _validate_answer_if_needed(self) -> None: - """ - Check whether the current pending answer to the user should be validated (played) - """ - if self._answer_speech is None: - return - - if self._agent_speaking and ( - self._validated_speech and not self._validated_speech.interrupted - ): - return - - self._pending_validation = False - self._transcribed_text = self._interim_text = "" - self._answer_speech.validate_speech() - self._log_debug("user speech validated") - - async def _start_speech( - self, data: _SpeechData, *, interrupt_current_if_possible: bool - ) -> None: - await self._wait_ready() - - async with self._start_speech_lock: - # interrupt the current speech if possible, otherwise wait before playing the new speech - if self._play_atask is not None: - if self._validated_speech is not None: - if ( - interrupt_current_if_possible - and self._validated_speech.allow_interruptions - ): - logger.debug("_start_speech - interrupting current speech") - 
self._validated_speech.interrupted = True - - else: - # pending speech isn't validated yet, OK to cancel - self._play_atask.cancel() - - with contextlib.suppress(asyncio.CancelledError): - await self._play_atask - - self._play_atask = asyncio.create_task( - self._play_speech_if_validated_task(data) - ) - - @utils.log_exceptions(logger=logger) - async def _play_speech_if_validated_task(self, data: _SpeechData) -> None: - """ - Start synthesis and playout the speech only if validated - """ - self._log_debug(f"play_speech_if_validated {data.user_question}") - - # reset volume before starting a new speech - self._vol_filter.reset() - playout_tx = playout_rx = aio.Chan[rtc.AudioFrame]() # playout channel - - tts_forwarder: transcription.TTSSegmentsForwarder = ( - transcription.NoopTTSSegmentsForwarder() - ) - if self._opts.transcription: - tts_forwarder = transcription.TTSSegmentsForwarder( - room=self._start_args.room, - participant=self._start_args.room.local_participant, - track=self._pub.sid, - sentence_tokenizer=self._opts.sentence_tokenizer, - word_tokenizer=self._opts.word_tokenizer, - hyphenate_word=self._opts.hyphenate_word, - speed=self._opts.transcription_speed, - ) - - if not self._opts.preemptive_synthesis: - await data.validation_future - - tts_co = self._synthesize_task(data, playout_tx, tts_forwarder) - _synthesize_task = asyncio.create_task(tts_co) - - try: - # wait for speech validation before playout - await data.validation_future - - # validated! - self._validated_speech = data - self._playout_start_time = time.time() - - if data.user_question is not None: - msg = allm.ChatMessage(text=data.user_question, role=allm.ChatRole.USER) - self._chat_ctx.messages.append(msg) - self.emit("user_speech_committed", self._chat_ctx, msg) - - self._log_debug("starting playout") - await self._playout_co(playout_rx, tts_forwarder) - - msg = allm.ChatMessage( - text=data.collected_text, - role=allm.ChatRole.ASSISTANT, - ) - - if data.add_to_ctx: - self._chat_ctx.messages.append(msg) - if data.interrupted: - self.emit("agent_speech_interrupted", self._chat_ctx, msg) - else: - self.emit("agent_speech_committed", self._chat_ctx, msg) - - self._log_debug("playout finished", extra={"interrupted": data.interrupted}) - finally: - self._validated_speech = None - with contextlib.suppress(asyncio.CancelledError): - _synthesize_task.cancel() - await _synthesize_task - - # make sure that _synthesize_task is finished before closing the transcription - # forwarder. 
pushing text/audio to the forwarder after closing it will raise an exception - await tts_forwarder.aclose() - self._log_debug("play_speech_if_validated_task finished") - - async def _synthesize_speech_co( - self, - data: _SpeechData, - playout_tx: aio.ChanSender[rtc.AudioFrame], - text: str, - tts_forwarder: transcription.TTSSegmentsForwarder, - ) -> None: - """synthesize speech from a string""" - data.collected_text += text - tts_forwarder.push_text(text) - tts_forwarder.mark_text_segment_end() - - start_time = time.time() - first_frame = True - audio_duration = 0.0 - - try: - async for audio in self._tts.synthesize(text): - if first_frame: - first_frame = False - dt = time.time() - start_time - self._log_debug(f"tts first frame in {dt:.2f}s") - - frame = audio.data - audio_duration += frame.samples_per_channel / frame.sample_rate - - playout_tx.send_nowait(frame) - tts_forwarder.push_audio(frame) - - finally: - tts_forwarder.mark_audio_segment_end() - playout_tx.close() - self._log_debug(f"tts finished synthesising {audio_duration:.2f}s of audio") - - async def _synthesize_streamed_speech_co( - self, - data: _SpeechData, - playout_tx: aio.ChanSender[rtc.AudioFrame], - streamed_text: AsyncIterable[str], - tts_forwarder: transcription.TTSSegmentsForwarder, - ) -> None: - """synthesize speech from streamed text""" - - async def _read_generated_audio_task(): - start_time = time.time() - first_frame = True - audio_duration = 0.0 - async for event in tts_stream: - if event.type == atts.SynthesisEventType.AUDIO: - if first_frame: - first_frame = False - dt = time.time() - start_time - self._log_debug(f"tts first frame in {dt:.2f}s (streamed)") - - assert event.audio is not None - frame = event.audio.data - audio_duration += frame.samples_per_channel / frame.sample_rate - tts_forwarder.push_audio(frame) - playout_tx.send_nowait(frame) - - self._log_debug( - f"tts finished synthesising {audio_duration:.2f}s audio (streamed)" - ) - - # otherwise, stream the text to the TTS - tts_stream = self._tts.stream() - read_task = asyncio.create_task(_read_generated_audio_task()) - - try: - async for seg in streamed_text: - data.collected_text += seg - tts_forwarder.push_text(seg) - tts_stream.push_text(seg) - - finally: - tts_forwarder.mark_text_segment_end() - tts_stream.mark_segment_end() - - await tts_stream.aclose() - await read_task - - tts_forwarder.mark_audio_segment_end() - playout_tx.close() - - @utils.log_exceptions(logger=logger) - async def _synthesize_task( - self, - data: _SpeechData, - playout_tx: aio.ChanSender[rtc.AudioFrame], - tts_forwarder: transcription.TTSSegmentsForwarder, - ) -> None: - """Synthesize speech from the source. 
Also run LLM inference when needed""" - if isinstance(data.source, str): - await self._synthesize_speech_co( - data, playout_tx, data.source, tts_forwarder - ) - elif isinstance(data.source, allm.LLMStream): - llm_stream = data.source - assistant_ctx = AssistantCallContext(self, llm_stream) - token = _ContextVar.set(assistant_ctx) - - async def _forward_llm_chunks(): - async for chunk in llm_stream: - alt = chunk.choices[0].delta.content - if not alt: - continue - yield alt - - await self._synthesize_streamed_speech_co( - data, playout_tx, _forward_llm_chunks(), tts_forwarder - ) - - if len(llm_stream.called_functions) > 0: - self.emit("function_calls_collected", assistant_ctx) - - await llm_stream.aclose() - - if len(llm_stream.called_functions) > 0: - self.emit("function_calls_finished", assistant_ctx) - - _ContextVar.reset(token) - else: - await self._synthesize_streamed_speech_co( - data, playout_tx, data.source, tts_forwarder - ) - - async def _playout_co( - self, - playout_rx: aio.ChanReceiver[rtc.AudioFrame], - tts_forwarder: transcription.TTSSegmentsForwarder, - ) -> None: - """ - Playout audio with the current volume. - The playout_rx is streaming the synthesized speech from the TTS provider to minimize latency - """ - assert ( - self._audio_source is not None - ), "audio source should be set before playout" - - def _should_break(): - eps = 1e-6 - assert self._validated_speech is not None - return ( - self._validated_speech.interrupted - and self._vol_filter.filtered() <= eps - ) - - first_frame = True - early_break = False - - async for frame in playout_rx: - if first_frame: - self._log_debug("agent started speaking") - self._plotter.plot_event("agent_started_speaking") - self._agent_speaking = True - self.emit("agent_started_speaking") - tts_forwarder.segment_playout_started() # we have only one segment - first_frame = False - - if _should_break(): - early_break = True - break - - # divide the frame by chunks of 20ms - ms20 = frame.sample_rate // 50 - i = 0 - while i < len(frame.data): - if _should_break(): - break - - rem = min(ms20, len(frame.data) - i) - data = frame.data[i : i + rem] - i += rem + @property + def fnc_ctx(self) -> llm.FunctionContext | None: + return self._impl._fnc_ctx - dt = 1 / len(data) - for si in range(0, len(data)): - vol = self._vol_filter.apply(dt, self._target_volume) - data[si] = int((data[si] / 32768) * vol * 32768) + @fnc_ctx.setter + def fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: + self._impl._fnc_ctx = fnc_ctx - await self._audio_source.capture_frame( - rtc.AudioFrame( - data=data.tobytes(), - sample_rate=frame.sample_rate, - num_channels=frame.num_channels, - samples_per_channel=rem, - ) - ) + @property + def chat_ctx(self) -> llm.ChatContext: + return self._impl._chat_ctx - if not first_frame: - self._log_debug("agent stopped speaking") - if not early_break: - tts_forwarder.segment_playout_finished() + @property + def llm(self) -> llm.LLM: + return self._impl._llm - self._plotter.plot_event("agent_stopped_speaking") - self._agent_speaking = False - self.emit("agent_stopped_speaking") + @property + def tts(self) -> text_to_speech.TTS: + return self._impl._tts - def _log_debug(self, msg: str, **kwargs: Any) -> None: - if self._opts.debug: - logger.debug(msg, **kwargs) + @property + def stt(self) -> stt.STT: + return self._impl._stt - async def _wait_ready(self) -> None: - """Wait for the assistant to be fully started""" - await self._ready_future + @property + def vad(self) -> vad.VAD: + return self._impl._vad diff --git 
a/livekit-agents/livekit/agents/voice_assistant/call_context.py b/livekit-agents/livekit/agents/voice_assistant/call_context.py new file mode 100644 index 000000000..62b8c7e05 --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/call_context.py @@ -0,0 +1,25 @@ +_ContextVar = contextvars.ContextVar("voice_assistant_contextvar") + + +class AssistantCallContext: + def __init__(self, assistant: "VoiceAssistant", llm_stream: allm.LLMStream) -> None: + self._assistant = assistant + self._metadata = dict() + self._llm_stream = llm_stream + + @staticmethod + def get_current() -> "AssistantCallContext": + return _ContextVar.get() + + @property + def assistant(self) -> "VoiceAssistant": + return self._assistant + + def store_metadata(self, key: str, value: Any) -> None: + self._metadata[key] = value + + def get_metadata(self, key: str, default: Any = None) -> Any: + return self._metadata.get(key, default) + + def llm_stream(self) -> allm.LLMStream: + return self._llm_stream diff --git a/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py b/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py new file mode 100644 index 000000000..a41ead6e8 --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import asyncio +import contextlib +from typing import AsyncIterable, Literal + +from livekit import rtc + +from .. import utils +from .log import logger + +EventTypes = Literal["playout_started", "playout_stopped"] + + +class PlayoutHandle: + def __init__(self, playout_source: AsyncIterable[rtc.AudioFrame]) -> None: + self._playout_source = playout_source + self._interrupted = False + self._done_fut = asyncio.Future() + + @property + def interrupted(self) -> bool: + return self._interrupted + + @property + def playing(self) -> bool: + return not self._done_fut.done() + + def interrupt(self) -> None: + if not self.playing: + return + + self._interrupted = True + + def __await__(self): + return self._done_fut.__await__() + + +class CancellableAudioSource(utils.EventEmitter[EventTypes]): + def __init__(self, *, source: rtc.AudioSource, alpha: float = 0.95) -> None: + super().__init__() + self._source = source + self._target_volume, self._smoothed_volume = 1.0, 1.0 + self._vol_filter = utils.ExpFilter(alpha=alpha) + self._playout_atask: asyncio.Task[None] | None = None + self._closed = False + + @property + def target_volume(self) -> float: + return self._target_volume + + @target_volume.setter + def target_volume(self, value: float) -> None: + self._target_volume = value + + async def aclose(self) -> None: + if self._closed: + return + + self._closed = True + + if self._playout_atask is not None: + await self._playout_atask + + def play(self, playout_source: AsyncIterable[rtc.AudioFrame]) -> PlayoutHandle: + if self._closed: + raise ValueError("cancellable source is closed") + + handle = PlayoutHandle(playout_source=playout_source) + self._playout_atask = asyncio.create_task( + self._playout_task(self._playout_atask, handle) + ) + return handle + + @utils.log_exceptions(logger=logger) + async def _playout_task( + self, + old_task: asyncio.Task[None] | None, + handle: PlayoutHandle, + ) -> None: + def _should_break(): + eps = 1e-6 + return handle.interrupted and self._vol_filter.filtered() <= eps + + first_frame = True + cancelled = False + + try: + if old_task is not None: + with contextlib.suppress(asyncio.CancelledError): + old_task.cancel() + await old_task + + async for frame 
in handle._playout_source: + if first_frame: + self.emit("playout_started") + first_frame = False + + if _should_break(): + cancelled = True + break + + # divide the frame by chunks of 20ms + ms20 = frame.sample_rate // 100 + i = 0 + while i < len(frame.data): + if _should_break(): + cancelled = True + break + + print("frame.data", frame.data, "volume", self._vol_filter.filtered()) + + rem = min(ms20, len(frame.data) - i) + data = frame.data[i : i + rem] + i += rem + + tv = self._target_volume if not handle.interrupted else 0.0 + dt = 1 / len(data) + for si in range(0, len(data)): + vol = self._vol_filter.apply(dt, tv) + data[si] = int((data[si] / 32768) * vol * 32768) + + chunk_frame = rtc.AudioFrame( + data=data.tobytes(), + sample_rate=frame.sample_rate, + num_channels=frame.num_channels, + samples_per_channel=rem, + ) + await self._source.capture_frame(chunk_frame) + finally: + if not first_frame: + self.emit("playout_stopped", cancelled) + + handle._done_fut.set_result(None) diff --git a/livekit-agents/livekit/agents/voice_assistant/human_input.py b/livekit-agents/livekit/agents/voice_assistant/human_input.py new file mode 100644 index 000000000..89804d416 --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/human_input.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import asyncio +import contextlib +from typing import Literal + +from livekit import rtc + +from .. import stt as speech_to_text +from .. import transcription, utils +from .. import vad as voice_activity_detection +from .log import logger + +EventTypes = Literal[ + "start_of_speech", + "vad_inference_done", + "end_of_speech", + "final_transcript", + "interim_transcript", +] + + +class HumanInput(utils.EventEmitter[EventTypes]): + def __init__( + self, + *, + room: rtc.Room, + vad: voice_activity_detection.VAD, + stt: speech_to_text.STT, + participant: rtc.RemoteParticipant, + ) -> None: + super().__init__() + self._room, self._vad, self._stt, self._participant = ( + room, + vad, + stt, + participant, + ) + self._subscribed_track: rtc.RemoteAudioTrack | None = None + self._recognize_atask: asyncio.Task[None] | None = None + + self._closed = False + self._speaking = False + self._speech_probability = 0.0 + + self._room.on("track_published", self._subscribe_to_microphone) + self._room.on("track_subscribed", self._subscribe_to_microphone) + self._subscribe_to_microphone() + + + async def aclose(self) -> None: + if self._closed: + raise RuntimeError("HumanInput already closed") + + self._closed = True + self._room.off("track_published", self._subscribe_to_microphone) + self._room.off("track_subscribed", self._subscribe_to_microphone) + self._speaking = False + + if self._recognize_atask is not None: + self._recognize_atask.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await self._recognize_atask + + @property + def speaking(self) -> bool: + return self._speaking + + @property + def speaking_probability(self) -> float: + return self._speech_probability + + def _subscribe_to_microphone(self, *args, **kwargs) -> None: + """ + Subscribe to the participant microphone if found and not already subscribed. + Do nothing if no track is found. 
+ """ + for publication in self._participant.tracks.values(): + if publication.source != rtc.TrackSource.SOURCE_MICROPHONE: + continue + + if not publication.subscribed: + publication.set_subscribed(True) + + if ( + publication.track is not None + and publication.track != self._subscribed_track + ): + self._subscribed_track = publication.track # type: ignore + if self._recognize_atask is not None: + self._recognize_atask.cancel() + + self._recognize_atask = asyncio.create_task( + self._recognize_task(rtc.AudioStream(self._subscribed_track)) # type: ignore + ) + break + + @utils.log_exceptions(logger=logger) + async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: + """ + Receive the frames from the user audio stream and detect voice activity. + """ + vad_stream = self._vad.stream() + stt_stream = self._stt.stream() + + stt_forwarder = transcription.STTSegmentsForwarder( + room=self._room, + participant=self._participant, + track=self._subscribed_track, + ) + + async def _audio_stream_co() -> None: + # forward the audio stream to the VAD and STT streams + async for ev in audio_stream: + stt_stream.push_frame(ev.frame) + vad_stream.push_frame(ev.frame) + + async def _vad_stream_co() -> None: + async for ev in vad_stream: + if ev.type == voice_activity_detection.VADEventType.START_OF_SPEECH: + self._speaking = True + self.emit("start_of_speech", ev) + elif ev.type == voice_activity_detection.VADEventType.INFERENCE_DONE: + self._speech_probability = ev.probability + self.emit("vad_inference_done", ev) + elif ev.type == voice_activity_detection.VADEventType.END_OF_SPEECH: + self._speaking = False + self.emit("end_of_speech", ev) + + async def _stt_stream_co() -> None: + async for ev in stt_stream: + stt_forwarder.update(ev) + if ev.type == speech_to_text.SpeechEventType.FINAL_TRANSCRIPT: + self.emit("final_transcript", ev) + elif ev.type == speech_to_text.SpeechEventType.INTERIM_TRANSCRIPT: + self.emit("interim_transcript", ev) + + try: + await asyncio.gather( + _audio_stream_co(), + _vad_stream_co(), + _stt_stream_co(), + ) + finally: + await asyncio.gather( + stt_forwarder.aclose(wait=False), + stt_stream.aclose(wait=False), + vad_stream.aclose(), + ) diff --git a/livekit-agents/livekit/agents/voice_assistant/impl.py b/livekit-agents/livekit/agents/voice_assistant/impl.py new file mode 100644 index 000000000..faecdda3b --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/impl.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import asyncio +import contextlib +from dataclasses import dataclass +from typing import AsyncIterable, Awaitable, Callable, Optional, Union + +from livekit import rtc + +from .. 
import aio, llm, stt, tokenize, tts, utils, vad +from .agent_output import AgentOutput, SynthesisHandle +from .cancellable_source import CancellableAudioSource +from .human_input import HumanInput +from .log import logger +from .plotter import AssistantPlotter + + +@dataclass +class _SpeechInfo: + source: str | llm.LLMStream | AsyncIterable[str] + allow_interruptions: bool + add_to_chat_ctx: bool + synthesis_handle: SynthesisHandle + + +WillCreateLLMStream = Callable[ + ["AssistantImpl", llm.ChatContext], + Union[Optional[llm.LLMStream], Awaitable[Optional[llm.LLMStream]]], +] + + +@dataclass(frozen=True) +class ImplOptions: + allow_interruptions: bool + int_speech_duration: float + int_min_words: int + preemptive_synthesis: bool + will_create_llm_stream: WillCreateLLMStream + plotting: bool + debug: bool + + # transcription & transcript analysis + transcription: bool + word_tokenizer: tokenize.WordTokenizer + sentence_tokenizer: tokenize.SentenceTokenizer + hyphenate_word: Callable[[str], list[str]] + transcription_speed: float + + +class AssistantImpl: + UPDATE_INTERVAL_S = 0.5 # 2tps + PLOT_INTERVAL_S = 0.05 # 20tps + + def __init__( + self, + *, + vad: vad.VAD, + stt: stt.STT, + llm: llm.LLM, + tts: tts.TTS, + emitter: utils.EventEmitter, + options: ImplOptions, + chat_ctx: llm.ChatContext, + fnc_ctx: llm.FunctionContext | None, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts + self._emitter = emitter + self._opts = options + self._loop = loop + + self._chat_ctx, self._fnc_ctx = chat_ctx, fnc_ctx + self._started, self._closed = False, False + self._plotter = AssistantPlotter(self._loop) + + self._human_input: HumanInput | None = None + self._agent_output: AgentOutput | None = None + self._ready_future = asyncio.Future() + + self._agent_answer_speech: _SpeechInfo | None = None + self._agent_answer_atask: asyncio.Task[None] | None = None + self._agent_playing_speech: _SpeechInfo | None = ( + None # speech currently being played + ) + self._queued_playouts: list[_SpeechInfo] = [] + + self._transcribed_text, self._transcribed_interim_text = "", "" + + def start( + self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None + ) -> None: + if self._started: + raise RuntimeError("voice assistant already started") + + room.on("participant_connected", self._on_participant_connected) + self._room, self._participant = room, participant + + if participant is not None: + if isinstance(participant, rtc.RemoteParticipant): + self._link_participant(participant.identity) + else: + self._link_participant(participant) + else: + # no participant provided, try to find the first in the room + for participant in self._room.participants.values(): + self._link_participant(participant.identity) + break + + self._main_atask = asyncio.create_task(self._main_task()) + + async def aclose(self) -> None: + if not self._started: + return + + self._room.off("participant_connected", self._on_participant_connected) + + async def say( + self, + source: str | llm.LLMStream | AsyncIterable[str], + *, + allow_interruptions: bool = True, + add_to_chat_ctx: bool = True, + ) -> None: + await self._ready_future + assert ( + self._agent_output is not None + ), "agent output should be initialized when ready" + + speech_source = source + if isinstance(speech_source, llm.LLMStream): + speech_source = _llm_stream_to_str_iterable(speech_source) + + synthesis_handle = self._agent_output.synthesize(transcript=speech_source) + speech = _SpeechInfo( + 
source=source, + allow_interruptions=allow_interruptions, + add_to_chat_ctx=add_to_chat_ctx, + synthesis_handle=synthesis_handle, + ) + self._queued_playouts.append(speech) + + def _on_participant_connected(self, participant: rtc.RemoteParticipant): + if self._human_input is not None: + return + + self._link_participant(participant.identity) + + def _link_participant(self, identity: str) -> None: + participant = self._room.participants_by_identity.get(identity) + if participant is None: + logger.error("_link_participant must be called with a valid identity") + return + + self._human_input = HumanInput( + room=self._room, + vad=self._vad, + stt=self._stt, + participant=participant, + ) + + def _on_human_start_of_speech(ev: vad.VADEvent) -> None: + self._plotter.plot_event("user_started_speaking") + self._emitter.emit("user_started_speaking") + + def _on_human_vad_updated(ev: vad.VADEvent) -> None: + tv = max(0, 1 - ev.probability) + self._audio_source.target_volume = tv + + self._plotter.plot_value("raw_vol", tv) + self._plotter.plot_value("vad_probability", ev.probability) + + if ev.duration >= self._opts.int_speech_duration: + self._interrupt_if_needed() + + def _on_human_end_of_speech(ev: vad.VADEvent) -> None: + print("END") + self._validate_answer_if_needed() + self._plotter.plot_event("user_started_speaking") + self._emitter.emit("user_stopped_speaking") + + def _on_human_interim_transcript(ev: stt.SpeechEvent) -> None: + self._transcribed_interim_text = ev.alternatives[0].text + + def _on_human_final_transcript(ev: stt.SpeechEvent) -> None: + self._transcribed_text += ev.alternatives[0].text + + print("received final transcript", self._transcribed_text) + # logger.debug(f"received final transcript: {self._transcribed_text}") + self._synthesize_answer(user_transcript=self._transcribed_text) + + self._human_input.on("start_of_speech", _on_human_start_of_speech) + self._human_input.on("vad_inference_done", _on_human_vad_updated) + self._human_input.on("end_of_speech", _on_human_end_of_speech) + self._human_input.on("interim_transcript", _on_human_interim_transcript) + self._human_input.on("final_transcript", _on_human_final_transcript) + + @utils.log_exceptions(logger=logger) + async def _main_task(self) -> None: + audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels) + track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source) + self._agent_publication = await self._room.local_participant.publish_track( + track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) + ) + + self._audio_source = CancellableAudioSource(source=audio_source) + self._agent_output = AgentOutput( + room=self._room, + source=self._audio_source, + llm=self._llm, + tts=self._tts, + ) + + self._ready_future.set_result(None) + + async def _update_loop_co(): + interval_s = AssistantImpl.UPDATE_INTERVAL_S + interval = aio.interval(interval_s) + while True: + await interval.tick() + + if len(self._queued_playouts) > 0: + speech = self._queued_playouts.pop() + self._agent_playing_speech = speech + await self._play_speech(speech) + self._agent_playing_speech = None + + async def _plotter_co(): + # plot volume and vad probability + + interval_s = AssistantImpl.UPDATE_INTERVAL_S + interval = aio.interval(interval_s) + while True: + await interval.tick() + + coros = [] + coros.append(_update_loop_co()) + if self._opts.plotting: + coros.append(_plotter_co()) + + await asyncio.gather(*coros) + + def _interrupt_if_needed(self) -> None: + """ + Check whether the current 
assistant speech should be interrupted + """ + if ( + self._agent_playing_speech is None + or not self._agent_playing_speech.allow_interruptions + or self._agent_playing_speech.synthesis_handle.interrupted + ): + return + + if self._opts.int_min_words != 0: + # check the final/interim transcribed text for the minimum word count + # to interrupt the agent speech + final_words = self._opts.word_tokenizer.tokenize( + text=self._transcribed_text + ) + interim_words = self._opts.word_tokenizer.tokenize( + text=self._transcribed_interim_text + ) + if ( + len(final_words) <= self._opts.int_min_words + and len(interim_words) <= self._opts.int_min_words + ): + return + + self._agent_playing_speech.synthesis_handle.interrupt() + + def _validate_answer_if_needed(self) -> None: + """ + Check if the user speech should be validated/played + """ + if self._agent_answer_speech is None or self._human_input is None: + return + + if ( + self._human_input.speaking + or self._agent_answer_speech.synthesis_handle.interrupted + ): + return + + # validate the answer & queue it for playout, also add the user question to the chat context + user_msg = llm.ChatMessage.create(text=self._transcribed_text, role="user") + self._chat_ctx.messages.append(user_msg) + self._emitter.emit("user_speech_committed", self._chat_ctx, user_msg) + + self._agent_playing_synthesis = self._agent_answer_speech + self._agent_answer_speech = None + self._transcribed_text, self._transcribed_interim_text = "", "" + self._queued_playouts.append(self._agent_playing_synthesis) + + print("validate answer") + + def _synthesize_answer(self, *, user_transcript: str): + """ + Synthesize the answer to the user question and make sure only one answer is synthesized at a time + """ + copied_ctx = self._chat_ctx.copy() + copied_ctx.messages.append( + llm.ChatMessage.create(text=user_transcript, role="user") + ) + + if self._agent_answer_speech is not None: + self._agent_answer_speech.synthesis_handle.interrupt() + + self._agent_answer_speech = None + + @utils.log_exceptions(logger=logger) + async def _synthesize_answer_task(old_task: asyncio.Task[None]) -> None: + # Use an async task to synthesize the agent answer to + # allow users to execute async code inside the will_create_llm_stream callback + assert ( + self._agent_output is not None + ), "agent output should be initialized when ready" + + if old_task is not None: + with contextlib.suppress(asyncio.CancelledError): + old_task.cancel() + await old_task + + llm_stream = self._opts.will_create_llm_stream(self, copied_ctx) + if asyncio.iscoroutine(llm_stream): + llm_stream = await llm_stream + + # fallback to default impl if no custom/user stream is returned + if llm_stream is None: + llm_stream = self._llm.chat(chat_ctx=copied_ctx, fnc_ctx=self._fnc_ctx) + + assert isinstance( + llm_stream, llm.LLMStream + ), "will_create_llm_stream should be a LLMStream" + + source = _llm_stream_to_str_iterable(llm_stream) + synthesis = self._agent_output.synthesize(transcript=source) + self._agent_answer_speech = _SpeechInfo( + source=llm_stream, + allow_interruptions=self._opts.allow_interruptions, + add_to_chat_ctx=True, + synthesis_handle=synthesis, + ) + + old_task = self._agent_answer_atask + self._agent_answer_atask = asyncio.create_task( + _synthesize_answer_task(old_task) + ) + + async def _play_speech(self, speech_info: _SpeechInfo) -> None: + assert ( + self._agent_output is not None + ), "agent output should be initialized when ready" + + if speech_info.synthesis_handle.interrupted: + return + + 
self._playing_synthesis = speech_info.synthesis_handle + play_handle = speech_info.synthesis_handle.play() + + # Wait for the playout of the speech to finish (interrupted or done) + # When the LLM is calling a tool, it doesn't generate any "speech"/"text" to play + # so awaiting the play_handle will end immediately. + print("play_handle", play_handle) + await play_handle + print("end", play_handle) + + collected_text = speech_info.synthesis_handle.collected_text + interrupted = speech_info.synthesis_handle.interrupted + if ( + isinstance(speech_info.source, llm.LLMStream) + and len(speech_info.source.function_calls) > 0 + and not interrupted + ): + # TODO(theomonnom): emit function calls events & add call context + + # run the user functions and automatically generate the LLM answer for it + # when they're all completed + called_fncs = speech_info.source.execute_functions() + tasks = [called_fnc.task for called_fnc in called_fncs] + await asyncio.gather(*tasks, return_exceptions=True) + + tool_calls = [] + tool_calls_results = [] + + for called_fnc in called_fncs: + # ignore the function calls that returns None + if called_fnc.result is None: + continue + + tool_calls.append(called_fnc.call_info) + tool_calls_results.append( + llm.ChatMessage.create_tool_from_called_function(called_fnc) + ) + + chat_ctx = speech_info.source.chat_ctx.copy() + chat_ctx.messages.extend(tool_calls) + chat_ctx.messages.extend(tool_calls_results) + + answer_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self._fnc_ctx) + answer_synthesis = self._agent_output.synthesize( + transcript=_llm_stream_to_str_iterable(answer_stream) + ) + await answer_synthesis.play() + + collected_text = answer_synthesis.collected_text + interrupted = answer_synthesis.interrupted + + if speech_info.add_to_chat_ctx: + msg = llm.ChatMessage.create(text=collected_text, role="assistant") + self._chat_ctx.messages.append(msg) + + if interrupted: + self._emitter.emit("agent_speech_interrupted", self._chat_ctx, msg) + else: + self._emitter.emit("agent_speech_committed", self._chat_ctx, msg) + + +async def _llm_stream_to_str_iterable(stream: llm.LLMStream) -> AsyncIterable[str]: + async for chunk in stream: + content = chunk.choices[0].delta.content + if content is not None: + yield content diff --git a/livekit-agents/livekit/agents/voice_assistant/log.py b/livekit-agents/livekit/agents/voice_assistant/log.py new file mode 100644 index 000000000..c9bd553e8 --- /dev/null +++ b/livekit-agents/livekit/agents/voice_assistant/log.py @@ -0,0 +1,3 @@ +import logging + +logger = logging.getLogger("livekit.agents.voice_assistant") diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index f998fa831..ea809b920 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -17,7 +17,7 @@ import asyncio import base64 from dataclasses import dataclass -from typing import Any, MutableSet +from typing import Any, Awaitable, MutableSet from livekit import rtc from livekit.agents import llm, utils @@ -48,7 +48,7 @@ def __init__( self._client = client or openai.AsyncClient(base_url=get_base_url(base_url)) self._running_fncs: MutableSet[asyncio.Task[Any]] = set() - async def chat( + def chat( self, *, chat_ctx: llm.ChatContext, @@ -65,7 +65,7 @@ async def chat( opts["tools"] = fncs_desc messages = _build_oai_context(chat_ctx, id(self)) - cmp = await 
self._client.chat.completions.create( + cmp = self._client.chat.completions.create( messages=messages, model=self._opts.model, n=n, @@ -74,38 +74,36 @@ async def chat( **opts, ) - return LLMStream(cmp, fnc_ctx) + return LLMStream(oai_stream=cmp, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) class LLMStream(llm.LLMStream): def __init__( self, - oai_stream: openai.AsyncStream[ChatCompletionChunk], + *, + oai_stream: Awaitable[openai.AsyncStream[ChatCompletionChunk]], + chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None, ) -> None: - super().__init__() - self._oai_stream = oai_stream - self._fnc_ctx = fnc_ctx - self._running_tasks: MutableSet[asyncio.Task[Any]] = set() + super().__init__(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) + self._awaitable_oai_stream = oai_stream + self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None # current function call that we're waiting for full completion (args are streamed) self._tool_call_id: str | None = None self._fnc_name: str | None = None self._fnc_raw_arguments: str | None = None - async def gather_function_results(self) -> list[llm.CalledFunction]: - await asyncio.gather(*self._running_tasks, return_exceptions=True) - return self._called_functions - async def aclose(self) -> None: - await self._oai_stream.close() + if self._oai_stream: + await self._oai_stream.close() - for task in self._running_tasks: - task.cancel() - - await asyncio.gather(*self._running_tasks, return_exceptions=True) + return await super().aclose() async def __anext__(self): + if not self._oai_stream: + self._oai_stream = await self._awaitable_oai_stream + async for chunk in self._oai_stream: for choice in chunk.choices: chat_chunk = self._parse_choice(choice) @@ -170,21 +168,18 @@ def _try_run_function(self, choice: Choice) -> llm.ChatChunk | None: ) return None - task, called_function = llm._oai_api.create_ai_function_task( + fnc_info = llm._oai_api.create_ai_function_info( self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments ) self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None - - self._running_tasks.add(task) - task.add_done_callback(self._running_tasks.remove) - self._called_functions.append(called_function) + self._function_calls_info.append(fnc_info) return llm.ChatChunk( choices=[ llm.Choice( delta=llm.ChoiceDelta( role="assistant", - tool_calls=[called_function], + tool_calls=[fnc_info], ), index=choice.index, ) diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index 6e9870daa..5fb083843 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -43,7 +43,7 @@ def __init__( self, *, min_speech_duration: float = 0.260, # 260ms - min_silence_duration: float = 0.15, # 150ms + min_silence_duration: float = 0.5, # 150ms padding_duration: float = 0.1, max_buffered_speech: float = 60.0, activation_threshold: float = 0.35, @@ -194,6 +194,7 @@ async def aclose(self) -> None: @agents.utils.log_exceptions(logger=logger) async def _main_task(self): pub_speaking = False + pub_duration = 0.0 pub_speech_buf = np.array([], dtype=np.int16) may_start_at_sample = -1 @@ -214,42 +215,30 @@ async def _main_task(self): inference_data = window_data.inference_data start_time = time.time() raw_prob = await asyncio.to_thread(lambda: self._model(inference_data)) - raw_speaking = raw_prob >= self._opts.activation_threshold inference_duration = 
time.time() - start_time window_duration = self._opts.window_size_samples / self._opts.sample_rate if inference_duration > window_duration: # slower than realtime logger.warning( - "vad inference took too long — slower than realtime: %f", + "vad inference took too long - slower than realtime: %f", inference_duration, ) - self._event_ch.send_nowait( - agents.vad.VADEvent( - type=agents.vad.VADEventType.INFERENCE_DONE, - samples_index=current_sample, - probability=raw_prob, - inference_duration=inference_duration, - speaking=raw_speaking, - ) - ) - current_sample += self._opts.window_size_samples - # append new data to current speech buffer pub_speech_buf = np.append(pub_speech_buf, window_data.original_data) - cl = self._opts.padding_duration + max_data_s = self._opts.padding_duration if not pub_speaking: - cl += self._opts.min_speech_duration + max_data_s += self._opts.min_speech_duration else: - cl += self._opts.max_buffered_speech + max_data_s += self._opts.max_buffered_speech - cl = int(cl) * self._original_sample_rate + cl = int(max_data_s) * self._original_sample_rate if len(pub_speech_buf) > cl: pub_speech_buf = pub_speech_buf[-cl:] # dispatch start/end when needed - if raw_speaking: + if raw_prob >= self._opts.activation_threshold: may_end_at_sample = -1 if may_start_at_sample == -1: @@ -260,12 +249,27 @@ async def _main_task(self): self._event_ch.send_nowait( agents.vad.VADEvent( type=agents.vad.VADEventType.START_OF_SPEECH, + duration=0.0, samples_index=current_sample, speaking=True, ) ) - else: + if pub_speaking: + pub_duration += window_duration + + self._event_ch.send_nowait( + agents.vad.VADEvent( + type=agents.vad.VADEventType.INFERENCE_DONE, + samples_index=current_sample, + duration=pub_duration, + probability=raw_prob, + inference_duration=inference_duration, + speaking=pub_speaking, + ) + ) + + if raw_prob < self._opts.activation_threshold: may_start_at_sample = -1 if may_end_at_sample == -1: @@ -285,13 +289,16 @@ async def _main_task(self): agents.vad.VADEvent( type=agents.vad.VADEventType.END_OF_SPEECH, samples_index=current_sample, - duration=len(pub_speech_buf) / self._original_sample_rate, + duration=pub_duration, frames=[frame], speaking=False, ) ) pub_speech_buf = np.array([], dtype=np.int16) + pub_duration = 0 + + current_sample += self._opts.window_size_samples self._event_ch.close()
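
The playout loop in cancellable_source.py ducks the agent's audio by splitting each frame into chunks of `sample_rate // 100` samples and scaling every sample by a smoothed volume that tracks `target_volume`. The standalone sketch below illustrates that arithmetic only; `ExpFilter` is a hypothetical stand-in for the assistant's internal `_vol_filter`, whose implementation is not part of this diff.

import numpy as np


class ExpFilter:
    """Hypothetical stand-in for the volume filter: an exponential smoother
    that moves the current volume toward a target."""

    def __init__(self, alpha: float = 30.0, value: float = 1.0) -> None:
        self._alpha = alpha
        self._value = value

    def apply(self, dt: float, target: float) -> float:
        # larger alpha * dt -> faster convergence toward the target volume
        k = min(1.0, self._alpha * dt)
        self._value += (target - self._value) * k
        return self._value


def duck_frame(samples: np.ndarray, sample_rate: int, target_volume: float,
               vol_filter: ExpFilter) -> np.ndarray:
    # split the int16 frame into chunks of sample_rate // 100 samples
    # (the chunk size used by the playout loop) and scale every sample
    # by the smoothed volume
    chunk = sample_rate // 100
    out = samples.astype(np.float32)
    i = 0
    while i < len(out):
        rem = min(chunk, len(out) - i)
        dt = 1.0 / rem
        for si in range(i, i + rem):
            out[si] *= vol_filter.apply(dt, target_volume)
        i += rem
    return np.clip(out, -32768, 32767).astype(np.int16)


if __name__ == "__main__":
    sr = 48000
    tone = (np.sin(np.linspace(0.0, 440 * 2 * np.pi, sr)) * 16000).astype(np.int16)
    ducked = duck_frame(tone, sr, target_volume=0.2, vol_filter=ExpFilter())
    print(tone[:4], ducked[:4])

In the patch the target volume is driven by the VAD probability (`target_volume = max(0, 1 - ev.probability)`), so the agent's playout gets quieter the more confident the VAD is that the user is speaking.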
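
When the played LLM stream contains function calls, `_play_speech` executes them, appends the tool calls and their results to a copy of the chat context, and asks the LLM again for a spoken answer. The sketch below shows the same two-pass pattern with plain OpenAI-style message dicts and a fake model; `fake_llm`, `get_weather`, and the message shapes are illustrative assumptions, not the livekit-agents API.

import asyncio
import json


async def get_weather(location: str) -> str:
    # hypothetical AI function; the real ones come from the user's FunctionContext
    return f"It is sunny in {location}."


TOOLS = {"get_weather": get_weather}


async def fake_llm(messages: list[dict]) -> dict:
    # stand-in model: the first pass requests a tool call,
    # the second pass answers using the tool result
    if not any(m["role"] == "tool" for m in messages):
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"id": "call_1", "name": "get_weather",
                 "arguments": json.dumps({"location": "Paris"})}
            ],
        }
    tool_output = next(m["content"] for m in messages if m["role"] == "tool")
    return {"role": "assistant", "content": f"According to the tool: {tool_output}"}


async def answer(chat_ctx: list[dict]) -> str:
    first = await fake_llm(chat_ctx)
    calls = first.get("tool_calls") or []
    if not calls:
        return first["content"]

    # execute every requested function, then extend a copy of the context with
    # the tool-call message and one tool-result message per call
    ctx = list(chat_ctx)
    ctx.append(first)
    for call in calls:
        fnc = TOOLS[call["name"]]
        result = await fnc(**json.loads(call["arguments"]))
        ctx.append({"role": "tool", "tool_call_id": call["id"], "content": result})

    # second pass: the model now answers with the tool results in context
    second = await fake_llm(ctx)
    return second["content"]


if __name__ == "__main__":
    ctx = [{"role": "user", "content": "What's the weather in Paris?"}]
    print(asyncio.run(answer(ctx)))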
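
The OpenAI plugin change makes `chat()` synchronous: it now hands the un-awaited completion request to `LLMStream`, which awaits it only once iteration starts. A minimal standalone sketch of that lazy-stream pattern, with no OpenAI dependency (all names below are illustrative):

from __future__ import annotations

import asyncio
from typing import AsyncIterator, Awaitable


class LazyStream:
    """Wraps an un-awaited request; the request is awaited on first iteration."""

    def __init__(self, awaitable_stream: Awaitable[AsyncIterator[str]]) -> None:
        self._awaitable = awaitable_stream
        self._stream: AsyncIterator[str] | None = None

    def __aiter__(self) -> LazyStream:
        return self

    async def __anext__(self) -> str:
        if self._stream is None:
            # the round-trip only happens the first time someone pulls a chunk
            self._stream = await self._awaitable
        return await self._stream.__anext__()


async def _request() -> AsyncIterator[str]:
    await asyncio.sleep(0.1)  # pretend this is the network call

    async def _gen() -> AsyncIterator[str]:
        for token in ("hello", " ", "world"):
            yield token

    return _gen()


def chat() -> LazyStream:
    # synchronous entrypoint, mirroring the new non-async LLM.chat()
    return LazyStream(_request())


async def main() -> None:
    stream = chat()  # nothing awaited yet
    async for token in stream:
        print(token, end="")
    print()


asyncio.run(main())

This keeps `chat()` call sites synchronous while deferring the actual request until the response is consumed.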