Skip to content

Commit

Permalink
Support specifying provider for python examples (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Aug 9, 2023
1 parent 6235cb9 commit aeb112d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python-api-examples/online-decode-files.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def get_args():
""",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--bpe-model",
type=str,
Expand Down Expand Up @@ -204,6 +211,7 @@ def main():
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
Expand All @@ -220,7 +228,6 @@ def main():
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)


streams = []
total_duration = 0
for wave_filename in args.sound_files:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def get_args():
help="Valid values are greedy_search and modified_beam_search",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

return parser.parse_args()


Expand All @@ -97,6 +104,7 @@ def create_recognizer():
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
decoding_method=args.decoding_method,
provider=args.provider,
)
return recognizer

Expand Down
11 changes: 11 additions & 0 deletions python-api-examples/speech-recognition-from-microphone.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def get_args():
""",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--bpe-model",
type=str,
Expand Down Expand Up @@ -148,10 +155,12 @@ def create_recognizer():
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
provider=args.provider,
context_score=args.context_score,
)
return recognizer


def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
Expand All @@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
tokens_table=tokens,
)


def main():
args = get_args()

Expand Down Expand Up @@ -205,6 +215,7 @@ def main():
last_result = result
print("\r{}".format(result), end="", flush=True)


if __name__ == "__main__":
devices = sd.query_devices()
print(devices)
Expand Down
8 changes: 8 additions & 0 deletions python-api-examples/streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser):
help="Feature dimension of the model",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)


def add_decoding_args(parser: argparse.ArgumentParser):
parser.add_argument(
Expand Down Expand Up @@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)

return recognizer
Expand Down

1 comment on commit aeb112d

@A-Raafat
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @csukuangfj
non_streaming_server.py is missing provider

Please sign in to comment.