Turn Based Coqui and Agent Improvements (livekit#172)
* Add async synthesize, xtts, and prompt to coqui TB

* add speechrecognition and aiohttp dependencies

* add optional memory arg to turn-based ChatGPTAgent

* fix mypy issue

* pr feedback

* fix py3.8 typing issue

* another py3.8 fix
zaptrem authored Jun 1, 2023
1 parent 49acbdb commit c35f9e5
Showing 5 changed files with 170 additions and 13 deletions.
20 changes: 19 additions & 1 deletion poetry.lock


2 changes: 2 additions & 0 deletions pyproject.toml
@@ -36,6 +36,8 @@ google-cloud-speech = {version = "^2.19.0", optional = true}
redis = {version = "^4.5.4", optional = true}
twilio = {version = "^8.1.0", optional = true}
nylas = {version = "^5.14.0", optional = true}
speechrecognition = "^3.10.0"
aiohttp = "^3.8.4"


[tool.poetry.group.lint.dependencies]
5 changes: 3 additions & 2 deletions vocode/turn_based/agent/chat_gpt_agent.py
@@ -23,6 +23,7 @@ def __init__(
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: int = 100,
        memory: Optional[ConversationBufferMemory] = None,
    ):
        super().__init__(initial_message=initial_message)
        openai.api_key = getenv("OPENAI_API_KEY", api_key)
@@ -35,7 +36,7 @@ def __init__(
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )
        self.memory = ConversationBufferMemory(return_messages=True)
        self.memory = memory if memory else ConversationBufferMemory(return_messages=True)
        if initial_message:
            self.memory.chat_memory.add_ai_message(initial_message)
        self.llm = ChatOpenAI( # type: ignore
@@ -48,4 +49,4 @@
        )

    def respond(self, human_input: str):
        return self.conversation.predict(input=human_input)
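A minimal usage sketch of the new optional memory argument (not part of this diff): a pre-populated ConversationBufferMemory can be injected so earlier conversation history carries over into a new agent. The system_prompt keyword is assumed from the existing constructor; the seeded messages are illustrative only.

import os
from langchain.memory import ConversationBufferMemory
from vocode.turn_based.agent.chat_gpt_agent import ChatGPTAgent

os.environ.setdefault("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")  # placeholder

# Seed a memory object (e.g. restored from an earlier session) and hand it to the agent.
shared_memory = ConversationBufferMemory(return_messages=True)
shared_memory.chat_memory.add_user_message("My name is Ada.")
shared_memory.chat_memory.add_ai_message("Nice to meet you, Ada!")

agent = ChatGPTAgent(
    system_prompt="You are a helpful assistant.",  # assumed existing parameter
    memory=shared_memory,  # new optional argument from this change
)
print(agent.respond("What is my name?"))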
2 changes: 2 additions & 0 deletions vocode/turn_based/synthesizer/base_synthesizer.py
@@ -4,3 +4,5 @@
class BaseSynthesizer:
    def synthesize(self, text) -> AudioSegment:
        raise NotImplementedError
    async def async_synthesize(self, text) -> AudioSegment:
        raise NotImplementedError
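A sketch of how a custom synthesizer might satisfy the new contract (illustrative only; SilenceSynthesizer is hypothetical, not part of this commit): subclasses now override both the blocking and the awaitable entry points.

import asyncio
from pydub import AudioSegment
from vocode.turn_based.synthesizer.base_synthesizer import BaseSynthesizer


class SilenceSynthesizer(BaseSynthesizer):
    """Hypothetical synthesizer that returns one second of silence per request."""

    def synthesize(self, text) -> AudioSegment:
        return AudioSegment.silent(duration=1000)

    async def async_synthesize(self, text) -> AudioSegment:
        # Offload the blocking call so the event loop is not held up.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.synthesize, text)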
154 changes: 144 additions & 10 deletions vocode/turn_based/synthesizer/coqui_synthesizer.py
@@ -1,29 +1,163 @@
import io
from typing import Optional
import re
import typing
from typing import Optional, List
from pydub import AudioSegment
import requests
from vocode import getenv
from vocode.turn_based.synthesizer.base_synthesizer import BaseSynthesizer
import aiohttp
import asyncio

COQUI_BASE_URL = "https://app.coqui.ai/api/v2/"
COQUI_BASE_URL = "https://app.coqui.ai/api/v2/samples"
DEFAULT_SPEAKER_ID = "d2bd7ccb-1b65-4005-9578-32c4e02d8ddf"
MAX_TEXT_LENGTH = 250 # The maximum length of text that can be synthesized at once


class CoquiSynthesizer(BaseSynthesizer):
    def __init__(self, voice_id: Optional[str] = None, api_key: Optional[str] = None):
    def __init__(
        self,
        voice_id: Optional[str] = None,
        voice_prompt: Optional[str] = None,
        use_xtts: bool = False,
        api_key: Optional[str] = None,
    ):
        self.voice_id = voice_id or DEFAULT_SPEAKER_ID
        self.voice_prompt = voice_prompt
        self.use_xtts = use_xtts
        self.api_key = getenv("COQUI_API_KEY", api_key)
    def synthesize(self, text: str) -> AudioSegment:
        url = COQUI_BASE_URL + "samples"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        body = {
            "text": text,
            "speaker_id": self.voice_id,
            "name": "unnamed",
        }
        text_chunks = self.split_text(text)
        audio_chunks = [self.synthesize_chunk(chunk) for chunk in text_chunks]
        return sum(audio_chunks) # type: ignore

    def synthesize_chunk(self, text: str) -> AudioSegment:
        url, headers, body = self.get_request(text)

        # Get the sample
        response = requests.post(url, headers=headers, json=body)
        assert response.ok, response.text
        sample = response.json()
        response = requests.get(sample["audio_url"])
        return AudioSegment.from_wav(io.BytesIO(response.content)) # type: ignore


    def split_text(self, string):
        # Base case: if the string is less than or equal to MAX_TEXT_LENGTH characters, return it as a single element array
        if len(string) <= MAX_TEXT_LENGTH:
            return [string.strip()]

        # Recursive case: find the index of the last sentence ender in the first MAX_TEXT_LENGTH characters of the string
        sentence_enders = [".", "!", "?"]
        index = -1
        for ender in sentence_enders:
            i = string[:MAX_TEXT_LENGTH].rfind(ender)
            if i > index:
                index = i

        # If there is a sentence ender, split the string at that index plus one and strip any spaces from both parts
        if index != -1:
            first_part = string[:index + 1].strip()
            second_part = string[index + 1:].strip()

        # If there is no sentence ender, find the index of the last comma in the first MAX_TEXT_LENGTH characters of the string
        else:
            index = string[:MAX_TEXT_LENGTH].rfind(",")
            # If there is a comma, split the string at that index plus one and strip any spaces from both parts
            if index != -1:
                first_part = string[:index + 1].strip()
                second_part = string[index + 1:].strip()

            # If there is no comma, find the index of the last space in the first MAX_TEXT_LENGTH characters of the string
            else:
                index = string[:MAX_TEXT_LENGTH].rfind(" ")
                # If there is a space, split the string at that index and strip any spaces from both parts
                if index != -1:
                    first_part = string[:index].strip()
                    second_part = string[index:].strip()

                # If there is no space, split the string at MAX_TEXT_LENGTH characters and strip any spaces from both parts
                else:
                    first_part = string[:MAX_TEXT_LENGTH].strip()
                    second_part = string[MAX_TEXT_LENGTH:].strip()

        # Append the first part to the result array
        result = [first_part]

        # Call the function recursively on the remaining part of the string and extend the result array with it, unless it is empty
        if second_part != "":
            result.extend(self.split_text(second_part))

        # Return the result array
        return result




    async def async_synthesize(self, text: str) -> AudioSegment:
        # This method is similar to the synthesize method, but it uses async IO to synthesize each chunk in parallel

        # Split the text into chunks of less than MAX_TEXT_LENGTH characters
        text_chunks = self.split_text(text)

        # Create a list of tasks for each chunk using asyncio.create_task()
        tasks = [
            asyncio.create_task(self.async_synthesize_chunk(chunk))
            for chunk in text_chunks
        ]

        # Wait for all tasks to complete using asyncio.gather()
        audio_chunks = await asyncio.gather(*tasks)

        # Concatenate and return the results
        return sum(audio_chunks) # type: ignore

    async def async_synthesize_chunk(self, text: str) -> AudioSegment:
        url, headers, body = self.get_request(text)

        # Create an aiohttp session and post the request asynchronously using await
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=body) as response:
                assert response.status == 201, (
                    await response.text() + url + str(headers) + str(body)
                )
                sample = await response.json()
                audio_url = sample["audio_url"]

            # Get the audio data asynchronously using await
            async with session.get(audio_url) as response:
                assert response.status == 200, "Coqui audio download failed"
                audio_data = await response.read()

            # Return an AudioSegment object from the audio data
            return AudioSegment.from_wav(io.BytesIO(audio_data)) # type: ignore

    def get_request(self, text: str) -> typing.Tuple[str, typing.Dict[str, str], typing.Dict[str, object]]:
        url = COQUI_BASE_URL
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        body = {
            "text": text,
            "speed": 1,
        }

        if self.use_xtts:
            url += "/xtts/"
            # If we have a voice prompt, use that instead of the voice ID
            if self.voice_prompt is not None:
                url += "render-from-prompt/"
                body["prompt"] = self.voice_prompt
            else:
                url += "render/"
                body["voice_id"] = self.voice_id
        else:
            if self.voice_prompt is not None:
                url += "/from-prompt/"
                body["prompt"] = self.voice_prompt
            else:
                body["voice_id"] = self.voice_id
        return url, headers, body
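A usage sketch of the new synthesizer options (not part of this commit; the API key value and output path are placeholders): a voice prompt can be supplied instead of a voice ID, requests can be routed to the XTTS endpoints, and long text is split into chunks and synthesized concurrently via async_synthesize.

import asyncio
from vocode.turn_based.synthesizer.coqui_synthesizer import CoquiSynthesizer

synthesizer = CoquiSynthesizer(
    voice_prompt="A calm, low-pitched narrator",  # used instead of voice_id when provided
    use_xtts=True,  # route requests to the XTTS endpoints
    api_key="YOUR_COQUI_API_KEY",  # placeholder
)

async def main():
    # Text longer than MAX_TEXT_LENGTH is split into chunks and rendered in parallel.
    audio = await synthesizer.async_synthesize(
        "Hello there! This is a longer piece of text that will be split into "
        "sentence-sized chunks and synthesized concurrently."
    )
    audio.export("output.wav", format="wav")

asyncio.run(main())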
