Update KITT and other plugins to use end_of_speech field #153

Merged · 4 commits · Feb 6, 2024

3 changes: 1 addition & 2 deletions .github/workflows/ruff.yml
@@ -7,7 +7,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"

- name: Install dependencies
run: |
@@ -19,4 +19,3 @@ jobs:

- name: Check format
run: ruff format --check .

19 changes: 10 additions & 9 deletions examples/kitt/kitt.py
@@ -66,7 +66,9 @@ def __init__(self, ctx: agents.JobContext):
self.chatgpt_plugin = ChatGPTPlugin(
prompt=PROMPT, message_capacity=20, model="gpt-4-1106-preview"
)
self.stt_plugin = STT()
self.stt_plugin = STT(
min_silence_duration=1000,
)
self.tts_plugin = TTS(
model_id="eleven_turbo_v2", sample_rate=ELEVEN_TTS_SAMPLE_RATE
)
@@ -126,28 +128,27 @@ async def process_track(self, track: rtc.Track):
await stream.flush()

async def process_stt_stream(self, stream):
buffered_text = ""
async for event in stream:
if not event.is_final or self._agent_state != AgentState.LISTENING:
continue
if event.is_final:
buffered_text = " ".join([buffered_text, event.alternatives[0].text])

alt = event.alternatives[0]
text = alt.text
if alt.confidence < 0.75 or text == "":
if not event.end_of_speech:
continue

await self.ctx.room.local_participant.publish_data(
json.dumps(
{
"text": text,
"text": buffered_text,
"timestamp": int(datetime.now().timestamp() * 1000),
}
),
topic="transcription",
)

msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=text)
msg = ChatGPTMessage(role=ChatGPTMessageRole.user, content=buffered_text)
chatgpt_stream = self.chatgpt_plugin.add_message(msg)
self.ctx.create_task(self.process_chatgpt_result(chatgpt_stream))
buffered_text = ""

async def process_chatgpt_result(self, text_stream):
# ChatGPT is streamed, so we'll flip the state immediately
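The kitt.py change above is the consuming side of the new field: interim is_final segments are accumulated into buffered_text, and the buffered utterance is only published and handed to ChatGPT once the STT stream reports end_of_speech. A minimal, self-contained sketch of that pattern (the SpeechEvent stand-in and the consume/fake_stream names are illustrative, not the actual livekit-agents API):

import asyncio
from dataclasses import dataclass

@dataclass
class SpeechEvent:
    # stand-in for stt.SpeechEvent; only the fields this sketch needs
    text: str
    is_final: bool
    end_of_speech: bool

async def consume(stream, on_utterance):
    buffered_text = ""
    async for event in stream:
        if event.is_final:
            # accumulate each finalized segment of the utterance
            buffered_text = " ".join([buffered_text, event.text]).strip()
        if not event.end_of_speech:
            continue
        # the speaker has stopped: flush the whole utterance and reset
        await on_utterance(buffered_text)
        buffered_text = ""

async def main():
    async def fake_stream():
        yield SpeechEvent("hello", is_final=True, end_of_speech=False)
        yield SpeechEvent("world", is_final=True, end_of_speech=True)

    async def on_utterance(text):
        print(text)  # prints "hello world" once, after end_of_speech

    await consume(fake_stream(), on_utterance)

asyncio.run(main())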
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/worker.py
@@ -311,7 +311,7 @@ def running(self) -> bool:
return self._running

@property
def api(self) -> api.LiveKitAPI | None:
def api(self) -> Optional[api.LiveKitAPI]:
return self._api


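A side note on the worker.py hunk above: Optional[api.LiveKitAPI] and api.LiveKitAPI | None are equivalent annotations, but the | form only evaluates at import time on Python 3.10+ (or under from __future__ import annotations), so the typing.Optional spelling is presumably the more conservative choice if the module still has to import on older interpreters. A toy illustration, not code from the repo:

from typing import Optional

def find_worker(worker_id: str) -> Optional[str]:   # evaluates on older Pythons too
    return None

def find_worker_310(worker_id: str) -> str | None:  # needs Python 3.10+ at import time
    return None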
@@ -24,6 +24,7 @@ class STTOptions:
smart_format: bool
endpointing: Optional[str]


class STT(stt.STT):
def __init__(
self,
@@ -271,6 +272,7 @@ def prerecorded_transcription_to_speech_event(

return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[
stt.SpeechData(
language=language or "",
@@ -281,6 +281,7 @@ def recognize_response_to_speech_event(
gg_alts = result.alternatives
return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[
stt.SpeechData(
language=result.language_code,
@@ -301,6 +302,9 @@ def streaming_recognize_response_to_speech_event(
gg_alts = result.alternatives
return stt.SpeechEvent(
is_final=result.is_final,
# Google STT does not have a separate end_of_speech indicator
# so we'll use is_final
end_of_speech=result.is_final,
alternatives=[
stt.SpeechData(
language=result.language_code,
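The plugin hunks all follow one pattern: each backend's result is mapped into an stt.SpeechEvent that now also carries end_of_speech — True for batch/prerecorded transcriptions, and, where the provider has no separate end-of-speech signal (the Google streaming case noted in the comment above), copied from is_final. A rough sketch of that mapping with stand-in types (BackendResult and its field names are hypothetical, not any provider's real response shape):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SpeechData:            # stand-in for stt.SpeechData
    language: str
    text: str

@dataclass
class SpeechEvent:           # stand-in for stt.SpeechEvent
    is_final: bool
    end_of_speech: bool
    alternatives: List[SpeechData]

@dataclass
class BackendResult:         # hypothetical provider response
    text: str
    language: str
    is_final: bool
    speech_ended: Optional[bool] = None  # None when the provider has no such signal

def to_speech_event(result: BackendResult) -> SpeechEvent:
    return SpeechEvent(
        is_final=result.is_final,
        # prefer the provider's own end-of-speech flag; otherwise fall back to
        # is_final, as the Google streaming conversion above does
        end_of_speech=result.speech_ended
        if result.speech_ended is not None
        else result.is_final,
        alternatives=[SpeechData(language=result.language, text=result.text)],
    )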
@@ -94,5 +94,6 @@ async def recognize(
def transcription_to_speech_event(transcription) -> stt.SpeechEvent:
return stt.SpeechEvent(
is_final=True,
end_of_speech=True,
alternatives=[stt.SpeechData(text=transcription.text, language="")],
)
2 changes: 1 addition & 1 deletion ruff.toml
@@ -1,4 +1,4 @@
line-length = 88
indent-width = 4

target-version = "py39"
target-version = "py310"
3 changes: 3 additions & 0 deletions tests/test_stt.py
@@ -32,6 +32,8 @@ async def recognize(stt: agents.stt.STT):
event = await stt.recognize(buffer=frame)
text = event.alternatives[0].text
assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.9
assert event.is_final
assert event.end_of_speech

async with asyncio.TaskGroup() as group:
for stt in stts:
@@ -72,6 +74,7 @@ async def stream(stt: agents.stt.STT):
if event.is_final:
text = event.alternatives[0].text
assert SequenceMatcher(None, text, TEST_AUDIO_TRANSCRIPT).ratio() > 0.8
assert event.end_of_speech
break

await stream.aclose()