From 07862cf4dffb5d966d1dfc504686c5dd932b3367 Mon Sep 17 00:00:00 2001 From: Neil Dwyer Date: Fri, 9 Feb 2024 10:48:02 -0800 Subject: [PATCH] Rewrite Deepgram to use WebSocket API (#156) Merging to create a base for the upcoming STT api changes. --- .../livekit/plugins/deepgram/stt.py | 287 +++++++++++------- .../livekit/plugins/deepgram/version.py | 2 +- .../livekit-plugins-deepgram/setup.py | 2 +- 3 files changed, 172 insertions(+), 119 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 1783b8e0c..d365b60eb 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -1,16 +1,23 @@ -from livekit import rtc, agents -from livekit.agents import stt -from livekit.agents.utils import AudioBuffer -from typing import Union, Optional -from .models import DeepgramModels, DeepgramLanguages -from dataclasses import dataclass +import asyncio import dataclasses -import os +import io +import json import logging -import asyncio -import deepgram +import os +import urllib import wave -import io +from dataclasses import dataclass +from typing import Optional, Union + +import aiohttp +from livekit import rtc +from livekit.agents import stt +from livekit.agents.utils import AudioBuffer, merge_frames + +from .models import DeepgramLanguages, DeepgramModels + +STREAM_KEEPALIVE_MSG: str = json.dumps({"type": "KeepAlive"}) +STREAM_CLOSE_MSG: str = json.dumps({"type": "CloseStream"}) # internal @@ -40,12 +47,10 @@ def __init__( min_silence_duration: int = 10, ) -> None: super().__init__(streaming_supported=True) - api_key = api_key or os.environ.get("DEEPGRAM_API_KEY") - if not api_key: + self._api_key = api_key or os.environ.get("DEEPGRAM_API_KEY") + if not self._api_key: raise ValueError("Deepgram API key is required") - dg_opts = deepgram.DeepgramClientOptions(api_key=api_key, url=api_url or "") - self._client = deepgram.DeepgramClient(config=dg_opts) self._config = STTOptions( language=language, detect_language=detect_language, @@ -74,6 +79,17 @@ def _sanitize_options( return config + def _config_to_query(self, config: STTOptions) -> str: + params = { + "model": config.model, + "punctuate": config.punctuate, + "detect_language": config.detect_language, + "smart_format": config.smart_format, + } + if config.language: + params["language"] = config.language + return urllib.parse.urlencode(params).lower() + async def recognize( self, *, @@ -81,9 +97,10 @@ async def recognize( language: Optional[Union[DeepgramLanguages, str]] = None, ) -> stt.SpeechEvent: config = self._sanitize_options(language=language) - + query_params = self._config_to_query(config) + url = f"https://api.deepgram.com/v1/listen?{query_params}" # Deepgram prerecorded API requires WAV/MP3, so we write our PCM into a wav buffer - buffer = agents.utils.merge_frames(buffer) + buffer = merge_frames(buffer) io_buffer = io.BytesIO() with wave.open(io_buffer, "wb") as wav: wav.setnchannels(buffer.num_channels) @@ -91,22 +108,21 @@ async def recognize( wav.setframerate(buffer.sample_rate) wav.writeframes(buffer.data) - source: deepgram.BufferSource = { - "buffer": io_buffer.getvalue(), - } - - dg_opts = deepgram.PrerecordedOptions( - model=config.model, - smart_format=config.smart_format, - language=config.language, - punctuate=config.punctuate, - detect_language=config.detect_language, - ) - - dg_res = await self._client.listen.asyncprerecorded.v("1").transcribe_file( - source, dg_opts - ) - return prerecorded_transcription_to_speech_event(config.language, dg_res) + async with aiohttp.ClientSession( + headers={ + "Authorization": f"Token {self._api_key}", + "Accept": "application/json", + "Content-Type": "audio/wav", + } + ) as session: + async with session.post( + url=url, + data=io_buffer.getvalue(), + ) as res: + json_res = await res.json() + return prerecorded_transcription_to_speech_event( + config.language, json_res + ) def stream( self, @@ -115,26 +131,26 @@ def stream( ) -> "SpeechStream": config = self._sanitize_options(language=language) return SpeechStream( - self._client, config, + api_key=self._api_key, ) class SpeechStream(stt.SpeechStream): def __init__( self, - client: deepgram.DeepgramClient, config: STTOptions, + api_key: str, sample_rate: int = 16000, num_channels: int = 1, ) -> None: super().__init__() - self._client = client self._config = config self._sample_rate = sample_rate self._num_channels = num_channels + self._api_key = api_key - self._queue = asyncio.Queue[rtc.AudioFrame]() + self._queue = asyncio.Queue() self._event_queue = asyncio.Queue[stt.SpeechEvent]() self._closed = False self._main_task = asyncio.create_task(self._run(max_retry=32)) @@ -149,87 +165,122 @@ def push_frame(self, frame: rtc.AudioFrame) -> None: if self._closed: raise ValueError("cannot push frame to closed stream") - self._queue.put_nowait(frame) + self._queue.put_nowait( + frame.remix_and_resample(self._sample_rate, self._num_channels) + ) async def flush(self) -> None: await self._queue.join() async def aclose(self) -> None: - self._main_task.cancel() - try: - await self._main_task - except asyncio.CancelledError: - pass + await self._queue.put(STREAM_CLOSE_MSG) + await self._main_task async def _run(self, max_retry: int) -> None: """Try to connect to Deepgram with exponential backoff and forward frames""" - retry_count = 0 - while True: - try: - self._live = self._client.listen.asynclive.v("1") - - opened = False + async with aiohttp.ClientSession() as session: + retry_count = 0 + ws: Optional[aiohttp.ClientWebSocketResponse] = None + listen_task: Optional[asyncio.Task] = None + keepalive_task: Optional[asyncio.Task] = None + while True: + try: + ws = await self._try_connect(session) + listen_task = asyncio.create_task(self._listen_loop(ws)) + keepalive_task = asyncio.create_task(self._keepalive_loop(ws)) + # break out of the retry loop if we are done + if await self._send_loop(ws): + keepalive_task.cancel() + await asyncio.wait_for(listen_task, timeout=5) + break + except Exception as e: + if retry_count > max_retry and max_retry > 0: + logging.error(f"failed to connect to Deepgram: {e}") + break + + retry_delay = min(retry_count * 5, 5) # max 5s + retry_count += 1 + logging.warning( + f"failed to connect to Deepgram: {e} - retrying in {retry_delay}s" + ) + await asyncio.sleep(retry_delay) - async def on_close(_, **kwargs) -> None: - nonlocal opened - opened = False + self._closed = True - async def on_transcript_received( - _, result: deepgram.LiveResultResponse, **kwargs - ) -> None: - if result.type != "Results": - return + async def _send_loop(self, ws: aiohttp.ClientWebSocketResponse) -> bool: + while not ws.closed: + data = await self._queue.get() + # fire and forget, we don't care if we miss frames in the error case + self._queue.task_done() + + if ws.closed: + raise Exception("websocket closed") + + if isinstance(data, rtc.AudioFrame): + await ws.send_bytes(data.data.tobytes()) + else: + if data == STREAM_CLOSE_MSG: + await ws.send_str(data) + return True + return False + + async def _keepalive_loop(self, ws: aiohttp.ClientWebSocketResponse) -> None: + while not ws.closed: + await ws.send_str(STREAM_KEEPALIVE_MSG) + await asyncio.sleep(5) + + async def _listen_loop(self, ws: aiohttp.ClientWebSocketResponse) -> None: + while not ws.closed: + msg = await ws.receive() + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + break - speech_event = live_transcription_to_speech_event( - self._config.language, result + try: + if msg.type == aiohttp.WSMsgType.TEXT: + data = json.loads(msg.data) + if data["type"] != "Results": + logging.warning("Skipping non-results message %s", data) + continue + stt_event = live_transcription_to_speech_event( + self._config.language, data ) - self._event_queue.put_nowait(speech_event) + await self._event_queue.put(stt_event) + continue + except Exception as e: + logging.error("Error handling message %s: %s", msg, e) + continue + + logging.warning("Unhandled message %s", msg) + + async def _try_connect( + self, session: aiohttp.ClientSession + ) -> aiohttp.ClientWebSocketResponse: + live_config = { + "model": self._config.model, + "punctuate": self._config.punctuate, + "smart_format": self._config.smart_format, + "interim_results": self._config.interim_results, + "encoding": "linear16", + "sample_rate": self._sample_rate, + "channels": self._num_channels, + "endpointing": str(self._config.endpointing or "10"), + } - self._live.on(deepgram.LiveTranscriptionEvents.Close, on_close) - self._live.on( - deepgram.LiveTranscriptionEvents.Transcript, - on_transcript_received, - ) + if self._config.language: + live_config["language"] = self._config.language - dg_opts = deepgram.LiveOptions( - model=self._config.model, - language=self._config.language, - encoding="linear16", - interim_results=self._config.interim_results, - channels=self._num_channels, - sample_rate=self._sample_rate, - smart_format=self._config.smart_format, - punctuate=self._config.punctuate, - endpointing=self._config.endpointing, - ) - await self._live.start(dg_opts) - opened = True - retry_count = 0 - - while opened: - frame = await self._queue.get() - frame = frame.remix_and_resample( - self._sample_rate, self._num_channels - ) - await self._live.send(frame.data.tobytes()) - self._queue.task_done() + query_params = urllib.parse.urlencode(live_config).lower() - except asyncio.CancelledError: - await asyncio.shield(self._live.finish()) - break - except Exception as e: - if retry_count > max_retry and max_retry > 0: - logging.error(f"failed to connect to Deepgram: {e}") - break - - retry_delay = min(retry_count * 5, 5) # max 5s - retry_count += 1 - logging.warning( - f"failed to connect to Deepgram: {e} - retrying in {retry_delay}s" - ) - await asyncio.sleep(retry_delay) + url = f"wss://api.deepgram.com/v1/listen?{query_params}" + ws = await session.ws_connect( + url, headers={"Authorization": f"Token {self._api_key}"} + ) - self._closed = True + return ws async def __anext__(self) -> stt.SpeechEvent: if self._closed and self._event_queue.empty(): @@ -240,22 +291,23 @@ async def __anext__(self) -> stt.SpeechEvent: def live_transcription_to_speech_event( language: Optional[str], - event: deepgram.LiveResultResponse, + event: dict, ) -> stt.SpeechEvent: - dg_alts = event.channel.alternatives # type: ignore - if not dg_alts: + try: + dg_alts = event["channel"]["alternatives"] + except KeyError: raise ValueError("no alternatives in response") return stt.SpeechEvent( - is_final=event.is_final or False, # could be None? - end_of_speech=event.speech_final or False, + is_final=event["is_final"] or False, # could be None? + end_of_speech=event["speech_final"] or False, alternatives=[ stt.SpeechData( language=language or "", - start_time=(alt.words[0].start if alt.words else 0) or 0, - end_time=(alt.words[-1].end if alt.words else 0) or 0, - confidence=alt.confidence or 0, - text=alt.transcript or "", + start_time=(alt["words"][0]["start"] if alt["words"] else 0) or 0, + end_time=(alt["words"][-1]["end"] if alt["words"] else 0) or 0, + confidence=alt["confidence"] or 0, + text=alt["transcript"] or "", ) for alt in dg_alts ], @@ -264,10 +316,11 @@ def live_transcription_to_speech_event( def prerecorded_transcription_to_speech_event( language: Optional[str], - event: deepgram.PrerecordedResponse, + event: dict, ) -> stt.SpeechEvent: - dg_alts = event.results.channels[0].alternatives # type: ignore - if not dg_alts: + try: + dg_alts = event["results"]["channels"][0]["alternatives"] + except KeyError: raise ValueError("no alternatives in response") return stt.SpeechEvent( @@ -276,11 +329,11 @@ def prerecorded_transcription_to_speech_event( alternatives=[ stt.SpeechData( language=language or "", - start_time=(alt.words[0].start if alt.words else 0) or 0, - end_time=(alt.words[-1].end if alt.words else 0) or 0, - confidence=alt.confidence or 0, + start_time=(alt["words"][0]["start"] if alt["words"] else 0) or 0, + end_time=(alt["words"][-1]["end"] if alt["words"] else 0) or 0, + confidence=alt["confidence"] or 0, # not sure why transcript is Optional inside DG SDK ... - text=alt.transcript or "", + text=alt["transcript"] or "", ) for alt in dg_alts ], diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py index 5307ff6f8..917f44bea 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.3" +__version__ = "0.1.4" diff --git a/livekit-plugins/livekit-plugins-deepgram/setup.py b/livekit-plugins/livekit-plugins-deepgram/setup.py index d70fa9e8c..4560668a1 100644 --- a/livekit-plugins/livekit-plugins-deepgram/setup.py +++ b/livekit-plugins/livekit-plugins-deepgram/setup.py @@ -49,9 +49,9 @@ packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.10.0", # deepgram-sdk requires 3.10 install_requires=[ - "deepgram-sdk >= 3.0, < 4.0", "livekit >= 0.8.0", "livekit-agents >= 0.3.0", + "aiohttp >= 3.7.4", ], package_data={}, project_urls={