From 72f886b058cf48aa230e60a5bfa66fac8d5bf1e2 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Tue, 6 Feb 2024 09:12:38 -0800 Subject: [PATCH] Update KITT and other plugins to use end_of_speech field (#153) * Update KITT and other plugins to use end_of_speech field Tested with KITT. It significantly improves the end of speech behavior so that we are giving it a 1s wait before starting to process user input. * ruff on 3.10 * use ruff action * fixed ruff --- .github/workflows/ruff.yml | 3 +-- examples/kitt/kitt.py | 19 ++++++++++--------- livekit-agents/livekit/agents/worker.py | 2 +- .../livekit/plugins/deepgram/stt.py | 2 ++ .../livekit/plugins/google/stt.py | 4 ++++ .../livekit/plugins/openai/stt.py | 1 + ruff.toml | 2 +- tests/test_stt.py | 3 +++ 8 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 21fb3c4a8..809d0a287 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -7,7 +7,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.9" + python-version: "3.10" - name: Install dependencies run: | @@ -19,4 +19,3 @@ jobs: - name: Check format run: ruff format --check . - diff --git a/examples/kitt/kitt.py b/examples/kitt/kitt.py index 0a5ae537c..08707a9a4 100644 --- a/examples/kitt/kitt.py +++ b/examples/kitt/kitt.py @@ -66,7 +66,9 @@ def __init__(self, ctx: agents.JobContext): self.chatgpt_plugin = ChatGPTPlugin( prompt=PROMPT, message_capacity=20, model="gpt-4-1106-preview" ) - self.stt_plugin = STT() + self.stt_plugin = STT( + min_silence_duration=1000, + ) self.tts_plugin = TTS( model_id="eleven_turbo_v2", sample_rate=ELEVEN_TTS_SAMPLE_RATE ) @@ -126,28 +128,27 @@ async def process_track(self, track: rtc.Track): await stream.flush() async def process_stt_stream(self, stream): + buffered_text = "" async for event in stream: - if not event.is_final or self._agent_state != AgentState.LISTENING: - continue + if event.is_final: + buffered_text = " ".join([buffered_text, event.alternatives[0].text]) - alt = event.alternatives[0] - text = alt.text - if alt.confidence < 0.75 or text == "": + if not event.end_of_speech: continue - await self.ctx.room.local_participant.publish_data( json.dumps( { - "text": text, + "text": buffered_text, "timestamp": int(datetime.now().timestamp() * 1000), } ), topic="transcription", ) - msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=text) + msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=buffered_text) chatgpt_stream = self.chatgpt_plugin.add_message(msg) self.ctx.create_task(self.process_chatgpt_result(chatgpt_stream)) + buffered_text = "" async def process_chatgpt_result(self, text_stream): # ChatGPT is streamed, so we'll flip the state immediately diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 89b5bf948..f0176ff61 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -311,7 +311,7 @@ def running(self) -> bool: return self._running @property - def api(self) -> api.LiveKitAPI | None: + def api(self) -> Optional[api.LiveKitAPI]: return self._api diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 9c9de53f8..1783b8e0c 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -24,6 +24,7 @@ class STTOptions: smart_format: bool endpointing: Optional[str] + class STT(stt.STT): def __init__( self, @@ -271,6 +272,7 @@ def prerecorded_transcription_to_speech_event( return stt.SpeechEvent( is_final=True, + end_of_speech=True, alternatives=[ stt.SpeechData( language=language or "", diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 937763132..7a924e66b 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -281,6 +281,7 @@ def recognize_response_to_speech_event( gg_alts = result.alternatives return stt.SpeechEvent( is_final=True, + end_of_speech=True, alternatives=[ stt.SpeechData( language=result.language_code, @@ -301,6 +302,9 @@ def streaming_recognize_response_to_speech_event( gg_alts = result.alternatives return stt.SpeechEvent( is_final=result.is_final, + # Google STT does not have a separate end_of_speech indicator + # so we'll use is_final + end_of_speech=result.is_final, alternatives=[ stt.SpeechData( language=result.language_code, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index 141921c29..457fec27f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -94,5 +94,6 @@ async def recognize( def transcription_to_speech_event(transcription) -> stt.SpeechEvent: return stt.SpeechEvent( is_final=True, + end_of_speech=True, alternatives=[stt.SpeechData(text=transcription.text, language="")], ) diff --git a/ruff.toml b/ruff.toml index 79dc5067c..cce6f04e7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,4 +1,4 @@ line-length = 88 indent-width = 4 -target-version = "py39" +target-version = "py310" diff --git a/tests/test_stt.py b/tests/test_stt.py index 37cc1da1a..a2bd597b1 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -32,6 +32,8 @@ async def recognize(stt: agents.stt.STT): event = await stt.recognize(buffer=frame) text = event.alternatives[0].text assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.9 + assert event.is_final + assert event.end_of_speech async with asyncio.TaskGroup() as group: for stt in stts: @@ -72,6 +74,7 @@ async def stream(stt: agents.stt.STT): if event.is_final: text = event.alternatives[0].text assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.8 + assert event.end_of_speech break await stream.aclose()