Skip to content

Commit

Permalink
Standalone diarization+ASR evaluation script (NVIDIA#5439)
Browse files Browse the repository at this point in the history
* first commit on eval_diar_with_asr.py

Signed-off-by: Taejin Park <[email protected]>

* Add a standalone diarization-ASR evaluation transcript

Signed-off-by: Taejin Park <[email protected]>

* Fixed examples in docstrings

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed staticmethod error

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added description on eval modes

Signed-off-by: Taejin Park <[email protected]>

* adding diar_infer_general.yaml

Signed-off-by: Taejin Park <[email protected]>

* fix msdd_model in general yaml file

Signed-off-by: Taejin Park <[email protected]>

* fixed errors in yaml file

Signed-off-by: Taejin Park <[email protected]>

* combine into 1 commit

Signed-off-by: Taejin Park <[email protected]>

* Added description on eval modes

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add MoE support for T5 model (w/o expert parallel) (NVIDIA#5409)

* clean

Signed-off-by: Abhinav Khattar <[email protected]>

* kwarg ref

Signed-off-by: Abhinav Khattar <[email protected]>

* fix

Signed-off-by: Abhinav Khattar <[email protected]>

* fix

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* extra args

Signed-off-by: Abhinav Khattar <[email protected]>

* test

Signed-off-by: Abhinav Khattar <[email protected]>

* rm prints

Signed-off-by: Abhinav Khattar <[email protected]>

* style

Signed-off-by: Abhinav Khattar <[email protected]>

* review comments

Signed-off-by: Abhinav Khattar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* review comments

Signed-off-by: Abhinav Khattar <[email protected]>

* review comments

Signed-off-by: Abhinav Khattar <[email protected]>

* fix

Signed-off-by: Abhinav Khattar <[email protected]>

Signed-off-by: Abhinav Khattar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix args (NVIDIA#5410) (NVIDIA#5416)

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>

* Fix for concat map dataset (NVIDIA#5133)

* change for concat map dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Exhaust longest dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: 1-800-BAD-CODE <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>

* Add temporary fix for CUDA issue in Dockerfile (NVIDIA#5421) (NVIDIA#5422)

Signed-off-by: Yu Yao <[email protected]>

Signed-off-by: Yu Yao <[email protected]>

Signed-off-by: Yu Yao <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>

* Fix GPT generation when using sentencepiece tokenizer (NVIDIA#5413) (NVIDIA#5428)

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Yi Dong <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Co-authored-by: Yi Dong <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>

* Support for finetuning and finetuning inference with .ckpt files & batch size refactoring (NVIDIA#5339)

* Initial refactor

Signed-off-by: MaximumEntropy <[email protected]>

* Resolve config before passing to load_from_checkpoint

Signed-off-by: MaximumEntropy <[email protected]>

* Fixes for model parallel and nemo restore

Signed-off-by: MaximumEntropy <[email protected]>

* Fixes for eval

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert config changes

Signed-off-by: MaximumEntropy <[email protected]>

* Refactor

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix typo

Signed-off-by: MaximumEntropy <[email protected]>

* Remove comments

Signed-off-by: MaximumEntropy <[email protected]>

* Minor

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix validation reconfiguration

Signed-off-by: MaximumEntropy <[email protected]>

* Remove old comment

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for test_ds

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Revert "Add temporary fix for CUDA issue in Dockerfile (NVIDIA#5421)" (NVIDIA#5431) (NVIDIA#5432)

This reverts commit 0718b17.

Co-authored-by: yaoyu-33 <[email protected]>

* [ITN] fix year date graph, cardinals extension for hundreds (NVIDIA#5435)

* wip

Signed-off-by: ekmb <[email protected]>

* add lociko's hundreds extension for cardinals

Signed-off-by: ekmb <[email protected]>

* add optional end

Signed-off-by: ekmb <[email protected]>

* restart ci

Signed-off-by: ekmb <[email protected]>

Signed-off-by: ekmb <[email protected]>

* update doc in terms of get_label for lang id model (NVIDIA#5366)

* reflect PR 5278 ion doc

Signed-off-by: fayejf <[email protected]>

* reflect comment

Signed-off-by: fayejf <[email protected]>

Signed-off-by: fayejf <[email protected]>

* Revert workaround for T5 that sets number of workers to 0 & sync_batch_comm=False (NVIDIA#5420) (NVIDIA#5433)

* Revert workers workaround

Signed-off-by: MaximumEntropy <[email protected]>

* Fix in config

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>

* Fixed bug in notebook (NVIDIA#5382) (NVIDIA#5394)

Signed-off-by: Virginia Adams <[email protected]>

Signed-off-by: Virginia Adams <[email protected]>

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>

* Fixing bug in Megatron BERT when loss mask is all zeros (NVIDIA#5424)

* Fixing bug when loss mask is fully zero

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_model.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Update dataset_utils.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

Signed-off-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>

* Use updated API for overlapping grad sync with pipeline parallelism (NVIDIA#5236)

Signed-off-by: Tim Moon <[email protected]>

Signed-off-by: Tim Moon <[email protected]>

* support to disable sequence length + 1 input tokens for each sample in MegatronGPT (NVIDIA#5363)

* support to disable sequence length + 1 input tokens for MegatronGPT

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Anmol Gupta <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>

* [TTS] Create script for processing TTS training audio (NVIDIA#5262)

* Create script for processing TTS training audio
* Update VAD trimming logic
* Remove unused import

Signed-off-by: Ryan <[email protected]>

* [TTS] remove useless logic for set_tokenizer. (NVIDIA#5430)

Signed-off-by: Xuesong Yang <[email protected]>

* Fix setting up of `ReduceLROnPlateau` learning rate scheduler (NVIDIA#5444)

* Fix tests

Signed-off-by: PeganovAnton <[email protected]>

* Add accidentally lost changes

Signed-off-by: PeganovAnton <[email protected]>

Signed-off-by: PeganovAnton <[email protected]>

* Create codeql.yml (NVIDIA#5445)

Signed-off-by: Somshubra Majumdar <[email protected]>

Signed-off-by: Somshubra Majumdar <[email protected]>

* Fix for getting tokenizer in character-based ASR models when using tarred dataset (NVIDIA#5442)

Signed-off-by: Jonghwan Hyeon <[email protected]>

Signed-off-by: Jonghwan Hyeon <[email protected]>

* Combine 5 commits

adding diar_infer_general.yaml

Signed-off-by: Taejin Park <[email protected]>

Update codeql.yml

Signed-off-by: Somshubra Majumdar <[email protected]>

Update codeql.yml

Signed-off-by: Somshubra Majumdar <[email protected]>

fix msdd_model in general yaml file

Signed-off-by: Taejin Park <[email protected]>

fixed errors in yaml file

Signed-off-by: Taejin Park <[email protected]>

* moved eval_der function and fixed tqdm options

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed minor error in docstrings

Signed-off-by: Taejin Park <[email protected]>

* removed score_labels and changed leave=True

Signed-off-by: Taejin Park <[email protected]>

Signed-off-by: Taejin Park <[email protected]>
Signed-off-by: Abhinav Khattar <[email protected]>
Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Yu Yao <[email protected]>
Signed-off-by: ekmb <[email protected]>
Signed-off-by: fayejf <[email protected]>
Signed-off-by: Virginia Adams <[email protected]>
Signed-off-by: Shanmugam Ramasamy <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Ryan <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: PeganovAnton <[email protected]>
Signed-off-by: Somshubra Majumdar <[email protected]>
Signed-off-by: Jonghwan Hyeon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Abhinav Khattar <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>
Co-authored-by: Shane Carroll <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
Co-authored-by: Yi Dong <[email protected]>
Co-authored-by: Evelina <[email protected]>
Co-authored-by: fayejf <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Co-authored-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: anmolgupt <[email protected]>
Co-authored-by: Anmol Gupta <[email protected]>
Co-authored-by: Ryan Langman <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: PeganovAnton <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Co-authored-by: Jonghwan Hyeon <[email protected]>
Signed-off-by: shane carroll <[email protected]>
  • Loading branch information
21 people committed Nov 26, 2022
1 parent 6831170 commit ac58218
Show file tree
Hide file tree
Showing 10 changed files with 749 additions and 3,189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,32 @@ def main(cfg):
# If RTTM is provided and DER evaluation
if diar_score is not None:
metric, mapping_dict, _ = diar_score
der_results = asr_diar_offline.gather_eval_results(metric, mapping_dict, trans_info_dict)
wer_results = asr_diar_offline.evaluate(trans_info_dict)
asr_diar_offline.print_errors(der_results, wer_results)

# Get session-level diarization error rate and speaker counting error
der_results = OfflineDiarWithASR.gather_eval_results(
diar_score=diar_score,
audio_rttm_map_dict=asr_diar_offline.AUDIO_RTTM_MAP,
trans_info_dict=trans_info_dict,
root_path=asr_diar_offline.root_path,
)

# Calculate WER and cpWER if reference CTM files exist
wer_results = OfflineDiarWithASR.evaluate(
hyp_trans_info_dict=trans_info_dict,
audio_file_list=asr_diar_offline.audio_file_list,
ref_ctm_file_list=asr_diar_offline.ctm_file_list,
)

# Print average DER, WER and cpWER
OfflineDiarWithASR.print_errors(der_results=der_results, wer_results=wer_results)

# Save detailed session-level evaluation results in `root_path`.
OfflineDiarWithASR.write_session_level_result_in_csv(
der_results=der_results,
wer_results=wer_results,
root_path=asr_diar_offline.root_path,
csv_columns=asr_diar_offline.csv_columns,
)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
# The configurations in this YAML file is optimized to show balanced performances on various types of domain. VAD is optimized on multilingual ASR datasets and diarizer is optimized on DIHARD3 development set.
# An example line in an input manifest file (`.json` format):
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
name: &name "ClusterDiarizer"

num_workers: 1
sample_rate: 16000
batch_size: 64

diarizer:
manifest_filepath: ???
out_dir: ???
oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
collar: 0.25 # Collar value for scoring
ignore_overlap: True # Consider or ignore overlap segments while scoring

vad:
model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set

parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets
window_length_in_sec: 0.63 # Window length in sec for VAD context input
shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction
smoothing: False # False or type of smoothing method (eg: median)
overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
onset: 0.5 # Onset threshold for detecting the beginning and end of a speech
offset: 0.3 # Offset threshold for detecting the end of a speech
pad_onset: 0.2 # Adding durations before each speech segment
pad_offset: 0.2 # Adding durations after each speech segment
min_duration_on: 0.5 # Threshold for small non_speech deletion
min_duration_off: 0.5 # Threshold for short speech segment deletion
filter_speech_first: True

speaker_embeddings:
model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
parameters:
window_length_in_sec: [1.9,1.2,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.

clustering:
parameters:
oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
sparse_search_volume: 10 # The higher the number, the more values will be examined with more time.
maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.

msdd_model:
model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
parameters:
use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
infer_batch_size: 25 # Batch size for MSDD inference.
sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
diar_window_length: 50 # The length of split short sequence when split_infer is True.
overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.

asr:
model_path: null # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
parameters:
asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)

ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
beam_width: 32
alpha: 0.5
beta: 2.5

realigning_lm_parameters: # Experimental feature
arpa_language_model: null # Provide a KenLM language model in .arpa format.
min_number_of_words: 3 # Min number of words for the left context.
max_number_of_words: 10 # Max number of words for the right context.
logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.

48 changes: 48 additions & 0 deletions nemo/collections/asr/metrics/der.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,54 @@ def score_labels(
return None


def evaluate_der(audio_rttm_map_dict, all_reference, all_hypothesis, diar_eval_mode='all'):
"""
Evaluate with a selected diarization evaluation scheme
AUDIO_RTTM_MAP (dict):
Dictionary containing information provided from manifestpath
all_reference (list[uniq_name,annotation]):
reference annotations for score calculation
all_hypothesis (list[uniq_name,annotation]):
hypothesis annotations for score calculation
diar_eval_mode (str):
Diarization evaluation modes
diar_eval_mode == "full":
DIHARD challenge style evaluation, the most strict way of evaluating diarization
(collar, ignore_overlap) = (0.0, False)
diar_eval_mode == "fair":
Evaluation setup used in VoxSRC challenge
(collar, ignore_overlap) = (0.25, False)
diar_eval_mode == "forgiving":
Traditional evaluation setup
(collar, ignore_overlap) = (0.25, True)
diar_eval_mode == "all":
Compute all three modes (default)
"""
eval_settings = []
if diar_eval_mode == "full":
eval_settings = [(0.0, False)]
elif diar_eval_mode == "fair":
eval_settings = [(0.25, False)]
elif diar_eval_mode == "forgiving":
eval_settings = [(0.25, True)]
elif diar_eval_mode == "all":
eval_settings = [(0.0, False), (0.25, False), (0.25, True)]
else:
raise ValueError("`diar_eval_mode` variable contains an unsupported value")

for collar, ignore_overlap in eval_settings:
diar_score = score_labels(
AUDIO_RTTM_MAP=audio_rttm_map_dict,
all_reference=all_reference,
all_hypothesis=all_hypothesis,
collar=collar,
ignore_overlap=ignore_overlap,
)
return diar_score


def calculate_session_cpWER_bruteforce(spk_hypothesis: List[str], spk_reference: List[str]) -> Tuple[float, str, str]:
"""
Calculate cpWER with actual permutations in brute-force way when LSA algorithm cannot deliver the correct result.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/clustering_diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _run_vad(self, manifest_file):
data.append(get_uniqname_from_filepath(file))

status = get_vad_stream_status(data)
for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader(), desc='vad', leave=False)):
for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader(), desc='vad', leave=True)):
test_batch = [x.to(self._device) for x in test_batch]
with autocast():
log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
Expand Down Expand Up @@ -342,7 +342,7 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: in

all_embs = torch.empty([0])
for test_batch in tqdm(
self._speaker_model.test_dataloader(), desc=f'[{scale_idx}/{num_scales}] extract embeddings', leave=False
self._speaker_model.test_dataloader(), desc=f'[{scale_idx+1}/{num_scales}] extract embeddings', leave=True
):
test_batch = [x.to(self._device) for x in test_batch]
audio_signal, audio_signal_len, labels, slices = test_batch
Expand Down
Loading

0 comments on commit ac58218

Please sign in to comment.