Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jayeshp19 committed Dec 26, 2024
1 parent b2b5614 commit c2d26fc
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 73 deletions.
79 changes: 79 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
from typing import Literal

import boto3


def _get_aws_credentials(
api_key: str | None, api_secret: str | None, region: str | None
):
region = region or os.environ.get("AWS_DEFAULT_REGION")
if not region:
raise ValueError(
"AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
)

# If API key and secret are provided, create a session with them
if api_key and api_secret:
session = boto3.Session(
aws_access_key_id=api_key,
aws_secret_access_key=api_secret,
region_name=region,
)
else:
# Use default credentials from environment or AWS config
session = boto3.Session(region_name=region)

# Validate if session credentials are available
credentials = session.get_credentials()
if not credentials or not credentials.access_key or not credentials.secret_key:
raise ValueError("No valid AWS credentials found.")
return credentials.access_key, credentials.secret_key


# Engine names accepted for the speech-engine option.
TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]

# Language codes accepted for the language option. The order of the members
# is kept as originally declared.
TTS_LANGUAGE = Literal[
    "arb", "cmn-CN", "cy-GB", "da-DK", "de-DE", "en-AU",
    "en-GB", "en-GB-WLS", "en-IN", "en-US", "es-ES", "es-MX",
    "es-US", "fr-CA", "fr-FR", "is-IS", "it-IT", "ja-JP",
    "hi-IN", "ko-KR", "nb-NO", "nl-NL", "pl-PL", "pt-BR",
    "pt-PT", "ro-RO", "ru-RU", "sv-SE", "tr-TR", "en-NZ",
    "en-ZA", "ca-ES", "de-AT", "yue-CN", "ar-AE", "fi-FI",
    "en-IE", "nl-BE", "fr-BE", "cs-CZ", "de-CH",
]

# Audio encodings accepted for the output-format option.
TTS_OUTPUT_FORMAT = Literal["mp3", "pcm"]
207 changes: 135 additions & 72 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,28 @@

from __future__ import annotations

import os
import asyncio
from dataclasses import dataclass

import boto3
import aiohttp
from aiobotocore.session import AioSession, get_session
from livekit import rtc
from livekit.agents import tts, utils

from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)

from ._utils import (
TTS_LANGUAGE,
TTS_OUTPUT_FORMAT,
TTS_SPEECH_ENGINE,
_get_aws_credentials,
)
from .log import logger

TTS_SAMPLE_RATE: int = 16000
Expand All @@ -28,11 +42,13 @@

@dataclass
class _TTSOptions:
# https://docs.aws.amazon.com/polly/latest/dg/generative-voices.html
voice: str | None = None
output_format: str | None = None # pcm or mp3
speech_engine: str | None = None # generative, neural, standard
speech_region: str | None = None
# https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
voice: str | None
output_format: TTS_OUTPUT_FORMAT
speech_engine: TTS_SPEECH_ENGINE
speech_region: str | None
sample_rate: int
language: TTS_LANGUAGE


class TTS(tts.TTS):
Expand All @@ -41,44 +57,50 @@ def __init__(
*,
voice: str | None = "Ruth",
aws_session: AioSession | None = None,
output_format: str = "pcm",
speech_engine: str = "generative",
language: TTS_LANGUAGE = "en-US",
output_format: TTS_OUTPUT_FORMAT = "pcm",
speech_engine: TTS_SPEECH_ENGINE = "generative",
sample_rate: int = 16000,
speech_region: str = "us-east-1",
speech_key: str | None = None,
speech_secret: str | None = None,
api_key: str | None = None,
api_secret: str | None = None,
) -> None:
"""
Create a new instance of AWS Polly TTS.
``api_key`` and ``api_secret`` must be set to your AWS access key ID and secret access key, either using the arguments or by setting the
``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environment variables.
See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on AWS Polly TTS.
Args:
voice (str, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
language (TTS_LANGUAGE, optional): Language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
output_format (TTS_OUTPUT_FORMAT, optional): The format in which the returned output will be encoded. Defaults to "pcm".
sample_rate (int, optional): The audio frequency specified in Hz. Defaults to 16000.
speech_engine (TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
speech_region (str, optional): The region to use for the synthesis. Defaults to "us-east-1".
api_key (str, optional): AWS access key id.
api_secret (str, optional): AWS secret access key.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=TTS_SAMPLE_RATE,
num_channels=TTS_NUM_CHANNELS,
)
credentials = boto3.Session().get_credentials()

speech_key = (
speech_key or os.environ.get("AWS_ACCESS_KEY_ID") or credentials.access_key
)
if not speech_key:
raise ValueError("AWS_ACCESS_KEY_ID must be set")

speech_secret = (
speech_secret
or os.environ.get("AWS_SECRET_ACCESS_KEY")
or credentials.secret_key
self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, speech_region
)
if not speech_secret:
raise ValueError("AWS_SECRET_ACCESS_KEY must be set")

speech_region = speech_region or os.environ.get("AWS_DEFAULT_REGION")
if not speech_region:
raise ValueError("AWS_DEFAULT_REGION must be set")

self._opts = _TTSOptions(
voice=voice,
output_format=output_format,
speech_engine=speech_engine,
speech_region=speech_region,
language=language,
)
self._session = aws_session

Expand All @@ -91,57 +113,98 @@ def _ensure_session(self) -> AioSession:
def synthesize(
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
segment_id: str | None = None,
) -> "ChunkedStream":
return ChunkedStream(text, self._opts, self._ensure_session())
return ChunkedStream(
tts=self,
text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
api_key=self._api_key,
api_secret=self._api_secret,
segment_id=segment_id,
)


class ChunkedStream(tts.ChunkedStream):
def __init__(self, text: str, opts: _TTSOptions, session: AioSession) -> None:
super().__init__()
self._text, self._opts, self._session = text, opts, session
def __init__(
self,
*,
tts: TTS,
text: str,
conn_options: APIConnectOptions,
opts: _TTSOptions,
session: AioSession,
api_key: str,
api_secret: str,
segment_id: str | None = None,
) -> None:
super().__init__(tts=tts, input_text=text, conn_options=conn_options)
self._opts = opts
self._api_key = api_key
self._segment_id = segment_id or utils.shortuuid()
self._api_secret = api_secret
self._session = session

@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
async def _run(self) -> None:
request_id = utils.shortuuid()
segment_id = utils.shortuuid()

async with self._session.create_client(
"polly", region_name=self._opts.speech_region
) as client:
response = await client.synthesize_speech(
Text=self._text,
OutputFormat=self._opts.output_format,
Engine=self._opts.speech_engine,
VoiceId=self._opts.voice,
TextType="text",
SampleRate=str(TTS_SAMPLE_RATE),
)
if "AudioStream" in response:
decoder = utils.codecs.Mp3StreamDecoder()
async with response["AudioStream"] as resp:
async for data, _ in resp.content.iter_chunks():
if self._opts.output_format == "mp3":
frames = decoder.decode_chunk(data)
for frame in frames:

try:
async with self._session.create_client(
"polly",
region_name=self._opts.speech_region,
aws_access_key_id=self._api_key,
aws_secret_access_key=self._api_secret,
) as client:
response = await client.synthesize_speech(
Text=self._input_text,
OutputFormat=self._opts.output_format,
Engine=self._opts.speech_engine,
VoiceId=self._opts.voice,
LanguageCode=self._opts.language,
TextType="text",
SampleRate=str(TTS_SAMPLE_RATE),
)
if "AudioStream" in response:
decoder = utils.codecs.Mp3StreamDecoder()
async with response["AudioStream"] as resp:
async for data, _ in resp.content.iter_chunks():
if self._opts.output_format == "mp3":
frames = decoder.decode_chunk(data)
for frame in frames:
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id,
segment_id=self._segment_id,
frame=frame,
)
)
else:
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id,
segment_id=segment_id,
frame=frame,
segment_id=self._segment_id,
frame=rtc.AudioFrame(
data=data,
sample_rate=TTS_SAMPLE_RATE,
num_channels=1,
samples_per_channel=len(data)
// 2, # 16-bit
),
)
)
else:
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id,
segment_id=segment_id,
frame=rtc.AudioFrame(
data=data,
sample_rate=TTS_SAMPLE_RATE,
num_channels=1,
samples_per_channel=len(data) // 2, # 16-bit
),
)
)
else:
logger.error("polly tts failed to synthesizes speech")
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
3 changes: 2 additions & 1 deletion tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from livekit import agents
from livekit.agents import APIConnectionError, tokenize, tts
from livekit.agents.utils import AudioBuffer, merge_frames
from livekit.plugins import aws, (
from livekit.plugins import (
aws,
azure,
cartesia,
deepgram,
Expand Down

0 comments on commit c2d26fc

Please sign in to comment.