Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposing the 'stop_history_eou_th' parameter #83

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[submodule "common"]
path = common
url = https://github.com/nvidia-riva/common.git
branch = main
url = https://github.com/sarane22/common.git
branch = endpointing_stop_eou_threshold_param
rmittal-github marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion common
16 changes: 11 additions & 5 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def add_asr_config_argparse_parameters(
"--start-history",
default=-1,
type=int,
help="Value to detect and initiate start of speech utterance",
help="Value (in milliseconds) to detect and initiate start of speech utterance",
)
parser.add_argument(
"--start-threshold",
Expand All @@ -64,19 +64,25 @@ def add_asr_config_argparse_parameters(
"--stop-history",
default=-1,
type=int,
help="Value to reset the endpoint detection history",
help="Value (in milliseconds) to detect end of utterance and reset decoder",
)
parser.add_argument(
"--stop-threshold",
default=-1.0,
type=float,
help="Threshold value for detecting the end of speech utterance",
)
parser.add_argument(
"--stop-history-eou",
default=-1,
type=int,
help="Value to determine the response history for endpoint detection",
help="Value (in milliseconds) to detect end of utterance for the 1st pass and generate an intermediate final transcript",
)
parser.add_argument(
"--stop-threshold",
"--stop-threshold-eou",
default=-1.0,
type=float,
help="Threshold value for detecting the end of speech utterance",
help="Threshold value for likelihood of blanks before detecting end of utterance",
)
virajkarandikar marked this conversation as resolved.
Show resolved Hide resolved
return parser

Expand Down
5 changes: 4 additions & 1 deletion riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def add_endpoint_parameters_to_config(
stop_history: int,
stop_history_eou: int,
stop_threshold: float,
stop_threshold_eou: float,
) -> None:
if not (start_history > 0 or start_threshold > 0 or stop_history > 0 or stop_history_eou > 0 or stop_threshold > 0):
if not (start_history > 0 or start_threshold > 0 or stop_history > 0 or stop_history_eou > 0 or stop_threshold > 0 or stop_threshold_eou > 0):
return

inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
Expand All @@ -146,6 +147,8 @@ def add_endpoint_parameters_to_config(
endpointing_config.stop_history_eou = stop_history_eou
if stop_threshold > 0:
endpointing_config.stop_threshold = stop_threshold
if stop_threshold_eou > 0:
endpointing_config.stop_threshold_eou = stop_threshold_eou
inner_config.endpointing_config.CopyFrom(endpointing_config)


Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ def streaming_transcription_worker(
interim_results=True,
)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
for _ in range(args.num_iterations):
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ def main() -> None:
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
sound_callback = None
try:
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ def main() -> None:
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
with args.input_file.open('rb') as fh:
data = fh.read()
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def main() -> None:
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
with riva.client.audio_io.MicrophoneStream(
args.sample_rate_hz,
Expand Down