Skip to content

Commit

Permalink
feat: add support to change rate and pitch (#14)
Browse files Browse the repository at this point in the history
* feat: add minimal workable code

* feat: add removal of html tags from text

* feat: add pitch and rate interpolation

* feat: add errors of incorrect pitch and rate

* feat(app): add pitch and rate limit

* style: improve style
  • Loading branch information
MrPandir authored Mar 16, 2024
1 parent 39cc8a6 commit 18c895d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 8 deletions.
13 changes: 11 additions & 2 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ def generate(
sample_rate: Annotated[
int, Parameter(examples=sample_rate_examples, default=48_000)
],
pitch: Annotated[int, Parameter(ge=0, le=100, default=50)],
rate: Annotated[int, Parameter(ge=0, le=100, default=50)],
) -> Response:
if len(text) > text_length_limit:
raise TextTooLongHTTPException(
{"text": text, "length": len(text), "max_length": text_length_limit}
)

try:
audio = tts.generate(text, speaker, sample_rate)
audio = tts.generate(text, speaker, sample_rate, pitch, rate)
except NotFoundModelException:
raise NotFoundSpeakerHTTPException({"speaker": speaker})
except NotCorrectTextException:
Expand All @@ -57,6 +59,9 @@ def generate(
raise InvalidSampleRateHTTPException(
{"sample_rate": sample_rate, "valid_sample_rates": tts.VALID_SAMPLE_RATES}
)
except (InvalidPitchException, InvalidRateException):
# This will never happen because litestar ensures compliance with the parameters `ge` and `le`.
pass
else:
return Response(audio, media_type="audio/wav")

Expand All @@ -65,12 +70,16 @@ def generate(
async def speakers() -> dict[str, list[str]]:
return tts.speakers


@get(["/", "/docs"], include_in_schema=False)
async def docs() -> Redirect:
return Redirect("/schema")


app = Litestar(
[generate, speakers, docs],
openapi_config=OpenAPIConfig(title="Silero TTS API", version="1.0.0", root_schema_site="swagger"),
openapi_config=OpenAPIConfig(
title="Silero TTS API", version="1.0.0", root_schema_site="swagger"
),
cors_config=CORSConfig(),
)
10 changes: 10 additions & 0 deletions tts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,13 @@ class InvalidSampleRateException(Exception):
def __init__(self, sample_rate: int) -> None:
self.sample_rate = sample_rate
super().__init__(f"Invalid sample rate {sample_rate}. Supported sample rates are 8 000, 24 000, and 48 000.")

class InvalidPitchException(Exception):
def __init__(self, pitch: int) -> None:
self.pitch = pitch
super().__init__(f"Invalid pitch {pitch}. Pitch should be in range from 0 to 100.")

class InvalidRateException(Exception):
def __init__(self, rate: int) -> None:
self.rate = rate
super().__init__(f"Invalid rate {rate}. Rate should be in range from 0 to 100.")
54 changes: 48 additions & 6 deletions tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,27 @@ def __init__(self):
for model_path in Path("models").glob("*.pt"):
self._load_model(model_path)

def generate(self, text: str, speaker: str, sample_rate: int) -> bytes:
def generate(
self, text: str, speaker: str, sample_rate: int, pitch: int, rate: int
) -> bytes:
model = self.model_by_speaker.get(speaker)
if not model:
raise NotFoundModelException(speaker)
if sample_rate not in self.VALID_SAMPLE_RATES:
raise InvalidSampleRateException(sample_rate)

if not 0 <= pitch <= 100:
raise InvalidPitchException(pitch)
if not 0 <= rate <= 100:
raise InvalidRateException(rate)

pitch = self._interpolate_pitch(pitch)
rate = self._interpolate_rate(rate)

text = self._delete_dashes(text)
tensor = self._generate_audio(model, text, speaker, sample_rate)
text = self._delete_html_brackets(text)

tensor = self._generate_audio(model, text, speaker, sample_rate, pitch, rate)
return self._convert_to_wav(tensor, sample_rate)

def _load_model(self, model_path: Path):
Expand All @@ -64,17 +76,46 @@ def _load_speakers(self, model: "TTSModelMultiAcc_v3", language: str):
self.speakers[language] = model.speakers
for speaker in model.speakers:
self.model_by_speaker[speaker] = model

def _delete_dashes(self, text: str) -> str:
# This fixes the problem:
# https://github.com/twirapp/silero-tts-api-server/issues/8
return text.replace("-", "").replace("‑", "")

def _delete_html_brackets(self, text: str) -> str:
# Safeguarding against pitch and rate modifications with HTML tags in text.
# And also prevents raising the error of generation of audio `ValueError`, if there is html tags.
return text.replace("<", "").replace(">", "")

def _interpolate_pitch(self, pitch: int) -> int:
# One interesting feature of the models is that when a pitch of -100 is input,
# it transforms to `1.0 + (-100 / 100) = 0`, making the sound equivalent to generating `1.0 + (0 / 100) = 1`.
# This makes the voice the same for 0 and 1
if pitch == 0:
return -101

SCALE_FACTOR = 2
OFFSET = -100
return pitch * SCALE_FACTOR + OFFSET

def _interpolate_rate(self, rate: int) -> int:
OFFSET = 50
return rate + OFFSET

def _generate_audio(
self, model: "TTSModelMultiAcc_v3", text: str, speaker: str, sample_rate: int
self,
model: "TTSModelMultiAcc_v3",
text: str,
speaker: str,
sample_rate: int,
pitch: int,
rate: int,
) -> torch.Tensor:
ssml_text = f"<speak><prosody pitch='+{pitch}%' rate='{rate}%'>{text}</prosody></speak>"
try:
return model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
return model.apply_tts(
ssml_text=ssml_text, speaker=speaker, sample_rate=sample_rate
)
except ValueError:
raise NotCorrectTextException(text)
except Exception as error:
Expand All @@ -84,7 +125,7 @@ def _generate_audio(

def _convert_to_wav(self, tensor: torch.Tensor, sample_rate: int) -> bytes:
audio = self._normalize_audio(tensor)
with BytesIO() as buffer, wave.open(buffer, 'wb') as wav:
with BytesIO() as buffer, wave.open(buffer, "wb") as wav:
wav.setnchannels(1) # mono
wav.setsampwidth(2) # quality is 16 bit. Do not change
wav.setframerate(sample_rate)
Expand All @@ -97,4 +138,5 @@ def _normalize_audio(self, tensor: torch.Tensor):
audio: np.ndarray = tensor.numpy() * MAX_INT16
return audio.astype(np.int16)


tts = TTS()

0 comments on commit 18c895d

Please sign in to comment.