Skip to content

Commit

Permalink
feat(tts): Adding support for zero shot model (#63)
Browse files Browse the repository at this point in the history
* feat(tts): Adding support for pflow model input.

* chore(tts): Updated SHA for common github repo
  • Loading branch information
atomer-nvidia authored Feb 28, 2024
1 parent 153ebf0 commit 8814150
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
57 changes: 46 additions & 11 deletions riva/client/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import riva.client.proto.riva_tts_pb2_grpc as rtts_srv
from riva.client import Auth
from riva.client.proto.riva_audio_pb2 import AudioEncoding

import wave

class SpeechSynthesisService:
"""
Expand All @@ -34,20 +34,27 @@ def synthesize(
language_code: str = 'en-US',
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
sample_rate_hz: int = 44100,
audio_prompt_file: Optional[str] = None,
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
quality: int = 20,
future: bool = False,
) -> Union[rtts.SynthesizeSpeechResponse, _MultiThreadedRendezvous]:
"""
Synthesizes an entire audio for text :param:`text`.
Args:
text (:obj:`str`): an input text.
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
text (:obj:`str`): An input text.
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
a server will select the first available model with correct :param:`language_code` value.
language_code (:obj:`str`): a language to use.
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
future (:obj:`bool`, defaults to :obj:`False`): whether to return an async result instead of usual
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
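audio_prompt_file (:obj:`pathlib.Path`, `optional`): Location of the WAV audio prompt file for the zero shot
model. Although annotated as a string, the value must provide an ``open('rb')`` method
(e.g. :obj:`pathlib.Path`), because the file contents are read through it.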
audio_prompt_encoding (:obj:`AudioEncoding`): Encoding of the audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
quality (:obj:`int`): The number of times the decoder is run. A higher number improves the quality of the
generated audio but also takes longer to generate. Valid values range from 1 to 40.
future (:obj:`bool`, defaults to :obj:`False`): Whether to return an async result instead of usual
response. You can get a response by calling ``result()`` method of the future object.
Returns:
Expand All @@ -64,6 +71,16 @@ def synthesize(
)
if voice_name is not None:
req.voice_name = voice_name
if audio_prompt_file is not None:
with wave.open(str(audio_prompt_file), 'rb') as wf:
rate = wf.getframerate()
req.zero_shot_data.sample_rate = rate
with audio_prompt_file.open('rb') as wav_f:
audio_data = wav_f.read()
req.zero_shot_data.audio_prompt = audio_data
req.zero_shot_data.encoding = audio_prompt_encoding
req.zero_shot_data.quality = quality

func = self.stub.Synthesize.future if future else self.stub.Synthesize
return func(req, metadata=self.auth.get_auth_metadata())

Expand All @@ -74,19 +91,26 @@ def synthesize_online(
language_code: str = 'en-US',
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
sample_rate_hz: int = 44100,
audio_prompt_file: Optional[str] = None,
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
quality: int = 20,
) -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
"""
Synthesizes and yields output audio chunks for text :param:`text` as the chunks
become available.
Args:
text (:obj:`str`): an input text.
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
text (:obj:`str`): An input text.
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
a server will select the first available model with correct :param:`language_code` value.
language_code (:obj:`str`): a language to use.
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
language_code (:obj:`str`): A language to use.
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
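audio_prompt_file (:obj:`pathlib.Path`, `optional`): Location of the WAV audio prompt file for the zero shot
model. Although annotated as a string, the value must provide an ``open('rb')`` method
(e.g. :obj:`pathlib.Path`), because the file contents are read through it.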
audio_prompt_encoding (:obj:`AudioEncoding`): Encoding of the audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
quality (:obj:`int`): The number of times the decoder is run. A higher number improves the quality of the
generated audio but also takes longer to generate. Valid values range from 1 to 40.
Yields:
:obj:`riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse`: a response with output. You may find
Expand All @@ -103,4 +127,15 @@ def synthesize_online(
)
if voice_name is not None:
req.voice_name = voice_name

if audio_prompt_file is not None:
with wave.open(str(audio_prompt_file), 'rb') as wf:
rate = wf.getframerate()
req.zero_shot_data.sample_rate = rate
with audio_prompt_file.open('rb') as wav_f:
audio_data = wav_f.read()
req.zero_shot_data.audio_prompt = audio_data
req.zero_shot_data.encoding = audio_prompt_encoding
req.zero_shot_data.quality = quality

return self.stub.SynthesizeOnline(req, metadata=self.auth.get_auth_metadata())

0 comments on commit 8814150

Please sign in to comment.