Skip to content

Commit

Permalink
Support setting max speakers for offline diarization (#97)
Browse files Browse the repository at this point in the history
* fix: accept input for max_speaker_count in asr/transcribe_file_offline

* fix: rename input field to diarization_max_speakers

* remove: redundant default value for max_speakers
  • Loading branch information
pskrunner14 authored Sep 20, 2024
1 parent b94b3a9 commit 9a2cd82
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
6 changes: 6 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def add_asr_config_argparse_parameters(
action='store_true',
help="Flag that controls if speaker diarization should be performed",
)
parser.add_argument(
"--diarization-max-speakers",
default=3,
type=int,
help="Max number of speakers to detect when performing speaker diarization",
)
parser.add_argument(
"--start-history",
default=-1,
Expand Down
6 changes: 5 additions & 1 deletion riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,14 @@ def add_audio_file_specs_to_config(
def add_speaker_diarization_to_config(
config: Union[rasr.RecognitionConfig],
diarization_enable: bool,
diarization_max_speakers: int,
) -> None:
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
if diarization_enable:
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
diarization_config = rasr.SpeakerDiarizationConfig(
enable_speaker_diarization=True,
max_speaker_count=diarization_max_speakers,
)
inner_config.diarization_config.CopyFrom(diarization_config)


Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main() -> None:
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
Expand Down

0 comments on commit 9a2cd82

Please sign in to comment.