Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tts): Adding support for pflow model input. #63

Merged
merged 2 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions riva/client/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import riva.client.proto.riva_tts_pb2_grpc as rtts_srv
from riva.client import Auth
from riva.client.proto.riva_audio_pb2 import AudioEncoding

import wave

class SpeechSynthesisService:
"""
Expand All @@ -34,20 +34,27 @@ def synthesize(
language_code: str = 'en-US',
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
sample_rate_hz: int = 44100,
audio_prompt_file: Optional[str] = None,
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
quality: int = 20,
future: bool = False,
) -> Union[rtts.SynthesizeSpeechResponse, _MultiThreadedRendezvous]:
"""
Synthesizes an entire audio for text :param:`text`.

Args:
text (:obj:`str`): an input text.
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
text (:obj:`str`): An input text.
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
a server will select the first available model with correct :param:`language_code` value.
language_code (:obj:`str`): a language to use.
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
future (:obj:`bool`, defaults to :obj:`False`): whether to return an async result instead of usual
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
audio but also takes longer to generate the audio. Ranges between 1-40.
future (:obj:`bool`, defaults to :obj:`False`): Whether to return an async result instead of usual
response. You can get a response by calling ``result()`` method of the future object.

Returns:
Expand All @@ -64,6 +71,16 @@ def synthesize(
)
if voice_name is not None:
req.voice_name = voice_name
if audio_prompt_file is not None:
with wave.open(str(audio_prompt_file), 'rb') as wf:
rate = wf.getframerate()
req.zero_shot_data.sample_rate = rate
with audio_prompt_file.open('rb') as wav_f:
audio_data = wav_f.read()
req.zero_shot_data.audio_prompt = audio_data
req.zero_shot_data.encoding = audio_prompt_encoding
req.zero_shot_data.quality = quality

func = self.stub.Synthesize.future if future else self.stub.Synthesize
return func(req, metadata=self.auth.get_auth_metadata())

Expand All @@ -74,19 +91,26 @@ def synthesize_online(
language_code: str = 'en-US',
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
sample_rate_hz: int = 44100,
audio_prompt_file: Optional[str] = None,
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
quality: int = 20,
) -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
"""
Synthesizes and yields output audio chunks for text :param:`text` as the chunks
becoming available.

Args:
text (:obj:`str`): an input text.
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
text (:obj:`str`): An input text.
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
a server will select the first available model with correct :param:`language_code` value.
language_code (:obj:`str`): a language to use.
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
language_code (:obj:`str`): A language to use.
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
audio but also takes longer to generate the audio. Ranges between 1-40.

Yields:
:obj:`riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse`: a response with output. You may find
Expand All @@ -103,4 +127,15 @@ def synthesize_online(
)
if voice_name is not None:
req.voice_name = voice_name

if audio_prompt_file is not None:
with wave.open(str(audio_prompt_file), 'rb') as wf:
rate = wf.getframerate()
req.zero_shot_data.sample_rate = rate
with audio_prompt_file.open('rb') as wav_f:
audio_data = wav_f.read()
req.zero_shot_data.audio_prompt = audio_data
req.zero_shot_data.encoding = audio_prompt_encoding
req.zero_shot_data.quality = quality

return self.stub.SynthesizeOnline(req, metadata=self.auth.get_auth_metadata())