Skip to content

Commit

Permalink
Update KITT and other plugins to use end_of_speech field (livekit#153)
Browse files Browse the repository at this point in the history
* Update KITT and other plugins to use end_of_speech field

Tested with KITT. It significantly improves the end-of-speech behavior: we now
wait for 1s of silence before starting to process user input.

* ruff on 3.10

* use ruff action

* fixed ruff
  • Loading branch information
davidzhao authored Feb 6, 2024
1 parent dfb7091 commit 72f886b
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 13 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"

- name: Install dependencies
run: |
Expand All @@ -19,4 +19,3 @@ jobs:

- name: Check format
run: ruff format --check .

19 changes: 10 additions & 9 deletions examples/kitt/kitt.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def __init__(self, ctx: agents.JobContext):
self.chatgpt_plugin = ChatGPTPlugin(
prompt=PROMPT, message_capacity=20, model="gpt-4-1106-preview"
)
self.stt_plugin = STT()
self.stt_plugin = STT(
min_silence_duration=1000,
)
self.tts_plugin = TTS(
model_id="eleven_turbo_v2", sample_rate=ELEVEN_TTS_SAMPLE_RATE
)
Expand Down Expand Up @@ -126,28 +128,27 @@ async def process_track(self, track: rtc.Track):
await stream.flush()

async def process_stt_stream(self, stream):
buffered_text = ""
async for event in stream:
if not event.is_final or self._agent_state != AgentState.LISTENING:
continue
if event.is_final:
buffered_text = " ".join([buffered_text, event.alternatives[0].text])

alt = event.alternatives[0]
text = alt.text
if alt.confidence < 0.75 or text == "":
if not event.end_of_speech:
continue

await self.ctx.room.local_participant.publish_data(
json.dumps(
{
"text": text,
"text": buffered_text,
"timestamp": int(datetime.now().timestamp() * 1000),
}
),
topic="transcription",
)

msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=text)
msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=buffered_text)
chatgpt_stream = self.chatgpt_plugin.add_message(msg)
self.ctx.create_task(self.process_chatgpt_result(chatgpt_stream))
buffered_text = ""

async def process_chatgpt_result(self, text_stream):
# ChatGPT is streamed, so we'll flip the state immediately
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def running(self) -> bool:
return self._running

@property
def api(self) -> api.LiveKitAPI | None:
def api(self) -> Optional[api.LiveKitAPI]:
return self._api


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class STTOptions:
smart_format: bool
endpointing: Optional[str]


class STT(stt.STT):
def __init__(
self,
Expand Down Expand Up @@ -271,6 +272,7 @@ def prerecorded_transcription_to_speech_event(

return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[
stt.SpeechData(
language=language or "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def recognize_response_to_speech_event(
gg_alts = result.alternatives
return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[
stt.SpeechData(
language=result.language_code,
Expand All @@ -301,6 +302,9 @@ def streaming_recognize_response_to_speech_event(
gg_alts = result.alternatives
return stt.SpeechEvent(
is_final=result.is_final,
# Google STT does not have a separate end_of_speech indicator
# so we'll use is_final
end_of_speech=result.is_final,
alternatives=[
stt.SpeechData(
language=result.language_code,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ async def recognize(
def transcription_to_speech_event(transcription) -> stt.SpeechEvent:
return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[stt.SpeechData(text=transcription.text, language="")],
)
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
line-length = 88
indent-width = 4

target-version = "py39"
target-version = "py310"
3 changes: 3 additions & 0 deletions tests/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ async def recognize(stt: agents.stt.STT):
event = await stt.recognize(buffer=frame)
text = event.alternatives[0].text
assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.9
assert event.is_final
assert event.end_of_speech

async with asyncio.TaskGroup() as group:
for stt in stts:
Expand Down Expand Up @@ -72,6 +74,7 @@ async def stream(stt: agents.stt.STT):
if event.is_final:
text = event.alternatives[0].text
assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.8
assert event.end_of_speech
break

await stream.aclose()
Expand Down

0 comments on commit 72f886b

Please sign in to comment.