Skip to content

Commit

Permalink
minor fixes to asr and tts clients
Browse files (browse the repository at this point in the history)
  • Loading branch information
virajkarandikar committed Jul 22, 2024
1 parent a0a8c66 commit 9b53842
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
4 changes: 2 additions & 2 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def add_asr_config_argparse_parameters(
)
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
parser.add_argument("--model-name", default="", help="Model name to be used.")
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding. Can be used multiple times to boost multiple words.")
parser.add_argument(
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."
"--boosted-lm-score", type=float, default=4.0, help="Recommended range for the boost score is 20 to 100. The higher the boost score, the more biased the ASR engine is towards this word."
)
parser.add_argument(
"--speaker-diarization",
Expand Down
29 changes: 22 additions & 7 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ def parse_args() -> argparse.Namespace:
"`--play-audio` or `--output-device`.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--input-file", help="A path to a local file to stream.")
parser.add_argument("--list-devices", action="store_true", help="List output devices indices")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--input-file", help="A path to a local file to stream.")
group.add_argument("--list-models", action="store_true", help="List available models.")
group.add_argument("--list-devices", action="store_true", help="List output devices indices")

parser.add_argument(
"--show-intermediate", action="store_true", help="Show intermediate transcripts as they are available."
)
Expand Down Expand Up @@ -51,11 +54,6 @@ def parse_args() -> argparse.Namespace:
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
args = parser.parse_args()
if not args.list_devices and args.input_file is None:
parser.error(
"You have to provide at least one of parameters `--input-file` and `--list-devices` whereas both "
"parameters are missing."
)
if args.play_audio or args.output_device is not None or args.list_devices:
import riva.client.audio_io
return args
Expand All @@ -68,6 +66,23 @@ def main() -> None:
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)

if args.list_models:
asr_models = dict()
config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(riva.client.proto.riva_asr_pb2.RivaSpeechRecognitionConfigRequest())
for model_config in config_response.model_config:
if model_config.parameters["streaming"] and model_config.parameters["type"]:
language_code = model_config.parameters['language_code']
if language_code in asr_models:
asr_models[language_code]["models"].append(model_config.model_name)
else:
asr_models[language_code] = {"models": [model_config.model_name]}

print("Available ASR models")
asr_models = dict(sorted(asr_models.items()))
print(asr_models)
return

config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
language_code=args.language_code,
Expand Down
12 changes: 7 additions & 5 deletions scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@

def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="A speech synthesis via Riva AI Services. You HAVE TO provide at least one of arguments "
"`--output`, `--play-audio`, `--list-devices`, `--output-device`.",
description="Speech synthesis via Riva AI Services",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--text", type=str, help="Text input to synthesize.")
group.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
group.add_argument("--list-voices", action="store_true", help="List available voices.")
parser.add_argument(
"--voice",
help="A voice name to use. If this parameter is missing, then the server will try a first available model "
"based on parameter `--language-code`.",
)
parser.add_argument("--text", type=str, required=False, help="Text input to synthesize.")
parser.add_argument(
"--audio_prompt_file",
type=Path,
Expand All @@ -35,8 +37,6 @@ def parse_args() -> argparse.Namespace:
help="Whether to play input audio simultaneously with transcribing. If `--output-device` is not provided, "
"then the default output audio device will be used.",
)
parser.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
parser.add_argument("--list-voices", action="store_true", help="List available voices.")
parser.add_argument("--output-device", type=int, help="Output device to use.")
parser.add_argument("--language-code", default='en-US', help="A language of input text.")
parser.add_argument(
Expand All @@ -62,6 +62,7 @@ def main() -> None:
args = parse_args()
if args.list_devices:
riva.client.audio_io.list_output_devices()
return

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.SpeechSynthesisService(auth)
Expand All @@ -87,6 +88,7 @@ def main() -> None:

tts_models = dict(sorted(tts_models.items()))
print(json.dumps(tts_models, indent=4))
return

if not args.text:
print("No input text provided")
Expand Down

0 comments on commit 9b53842

Please sign in to comment.