diff --git a/.changeset/ninety-items-yell.md b/.changeset/ninety-items-yell.md new file mode 100644 index 000000000..7bea9e8fd --- /dev/null +++ b/.changeset/ninety-items-yell.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +voiceassistant: add VoiceAssistantState diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 18d0b8cb2..f0a3dd17c 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -34,6 +34,8 @@ "function_calls_finished", ] +VoiceAssistantState = Literal["initializing", "listening", "thinking", "speaking"] + _CallContextVar = contextvars.ContextVar["AssistantCallContext"]( "voice_assistant_contextvar" @@ -201,6 +203,8 @@ def __init__( self._last_end_of_speech_time: float | None = None + self._update_state_task: asyncio.Task | None = None + @property def fnc_ctx(self) -> FunctionContext | None: return self._fnc_ctx @@ -301,6 +305,23 @@ async def say( new_handle.initialize(source=source, synthesis_handle=synthesis_handle) self._add_speech_for_playout(new_handle) + def _update_state(self, state: VoiceAssistantState, delay: float = 0.0): + """Set the current state of the voice assistant""" + + @utils.log_exceptions(logger=logger) + async def _run_task(delay: float) -> None: + await asyncio.sleep(delay) + + if self._room.isconnected(): + await self._room.local_participant.set_attributes( + {"voice_assistant.state": state} + ) + + if self._update_state_task is not None: + self._update_state_task.cancel() + + self._update_state_task = asyncio.create_task(_run_task(delay)) + async def aclose(self) -> None: """Close the voice assistant""" if not self._started: @@ -333,6 +354,7 @@ def _on_start_of_speech(ev: vad.VADEvent) -> None: self._plotter.plot_event("user_started_speaking") self.emit("user_started_speaking") self._deferred_validation.on_human_start_of_speech(ev) + self._update_state("listening") def _on_vad_updated(ev: vad.VADEvent) -> None: if not self._track_published_fut.done(): @@ -393,6 +415,7 @@ async def _main_task(self) -> None: if self._opts.plotting: await self._plotter.start() + self._update_state("initializing") audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels) track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source) self._agent_publication = await self._room.local_participant.publish_track( @@ -410,10 +433,12 @@ async def _main_task(self) -> None: def _on_playout_started() -> None: self._plotter.plot_event("agent_started_speaking") self.emit("agent_started_speaking") + self._update_state("speaking") def _on_playout_stopped(interrupted: bool) -> None: self._plotter.plot_event("agent_stopped_speaking") self.emit("agent_stopped_speaking") + self._update_state("listening") agent_playout.on("playout_started", _on_playout_started) agent_playout.on("playout_stopped", _on_playout_stopped) @@ -439,6 +464,9 @@ def _synthesize_agent_reply(self) -> None: if self._pending_agent_reply is not None: self._pending_agent_reply.interrupt() + if self._human_input is not None and not self._human_input.speaking: + self._update_state("thinking", 0.2) + self._pending_agent_reply = new_handle = SpeechHandle.create_assistant_reply( allow_interruptions=self._opts.allow_interruptions, add_to_chat_ctx=True,