Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

voiceassistant: add VoiceAssistantState #654

Merged
merged 13 commits into from
Sep 3, 2024
5 changes: 5 additions & 0 deletions .changeset/ninety-items-yell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

voiceassistant: add VoiceAssistantState
28 changes: 28 additions & 0 deletions livekit-agents/livekit/agents/voice_assistant/voice_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
"function_calls_finished",
]

VoiceAssistantState = Literal["initializing", "listening", "thinking", "speaking"]


_CallContextVar = contextvars.ContextVar["AssistantCallContext"](
"voice_assistant_contextvar"
Expand Down Expand Up @@ -201,6 +203,8 @@ def __init__(

self._last_end_of_speech_time: float | None = None

self._update_state_task: asyncio.Task | None = None

@property
def fnc_ctx(self) -> FunctionContext | None:
    """Return the function context stored on this assistant, or ``None`` if unset."""
    return self._fnc_ctx
Expand Down Expand Up @@ -301,6 +305,23 @@ async def say(
new_handle.initialize(source=source, synthesis_handle=synthesis_handle)
self._add_speech_for_playout(new_handle)

def _update_state(self, state: VoiceAssistantState, delay: float = 0.0):
    """Schedule publishing ``state`` as the assistant's current state.

    The update runs in a background task: it waits ``delay`` seconds and
    then, if the room is connected, writes ``state`` to the local
    participant's ``voice_assistant.state`` attribute. Any update task
    scheduled by an earlier call is cancelled first, so only the latest
    request remains pending.

    Args:
        state: The state value to publish.
        delay: Seconds to wait before publishing (defaults to no wait).
    """

    @utils.log_exceptions(logger=logger)
    async def _run_task(delay: float) -> None:
        await asyncio.sleep(delay)

        room = self._room
        if not room.isconnected():
            # Not connected: skip the attribute write entirely.
            return

        attributes = {"voice_assistant.state": state}
        await room.local_participant.set_attributes(attributes)

    # Cancel the previously scheduled update, if any, before replacing it.
    previous_task = self._update_state_task
    if previous_task is not None:
        previous_task.cancel()

    self._update_state_task = asyncio.create_task(_run_task(delay))

async def aclose(self) -> None:
"""Close the voice assistant"""
if not self._started:
Expand Down Expand Up @@ -333,6 +354,7 @@ def _on_start_of_speech(ev: vad.VADEvent) -> None:
self._plotter.plot_event("user_started_speaking")
self.emit("user_started_speaking")
self._deferred_validation.on_human_start_of_speech(ev)
self._update_state("listening")

def _on_vad_updated(ev: vad.VADEvent) -> None:
if not self._track_published_fut.done():
Expand Down Expand Up @@ -393,6 +415,7 @@ async def _main_task(self) -> None:
if self._opts.plotting:
await self._plotter.start()

self._update_state("initializing")
audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source)
self._agent_publication = await self._room.local_participant.publish_track(
Expand All @@ -410,10 +433,12 @@ async def _main_task(self) -> None:
def _on_playout_started() -> None:
self._plotter.plot_event("agent_started_speaking")
self.emit("agent_started_speaking")
self._update_state("speaking")

def _on_playout_stopped(interrupted: bool) -> None:
self._plotter.plot_event("agent_stopped_speaking")
self.emit("agent_stopped_speaking")
self._update_state("listening")

agent_playout.on("playout_started", _on_playout_started)
agent_playout.on("playout_stopped", _on_playout_stopped)
Expand All @@ -439,6 +464,9 @@ def _synthesize_agent_reply(self) -> None:
if self._pending_agent_reply is not None:
self._pending_agent_reply.interrupt()

if self._human_input is not None and not self._human_input.speaking:
self._update_state("thinking", 0.2)

self._pending_agent_reply = new_handle = SpeechHandle.create_assistant_reply(
allow_interruptions=self._opts.allow_interruptions,
add_to_chat_ctx=True,
Expand Down
Loading