From 2d529829a0f6711f5c9eb7dd8d0687b8f9853bff Mon Sep 17 00:00:00 2001 From: Zhengyang Chen Date: Fri, 14 Jul 2023 16:03:05 +0800 Subject: [PATCH] [examples] Add SRE16 recipe. (#177) * init sre recipe * [examples] add sre recipe * Update README.md * [examples] add sre recipe, delete used files * merge from master * update specify sample number each epoch * remove trailing whitespace * fix the repeat read dataset problem in the evaluation process and update results * fix the repeat read dataset problem in the evaluation process and update results * update as hongji mentioned * Update README.md * Update sre recipe README.md * Update recipe part in README.md --- README.md | 7 +- examples/sre/v2/README.md | 21 ++ examples/sre/v2/conf/resnet.yaml | 81 +++++++ examples/sre/v2/local/extract_sre.sh | 95 ++++++++ examples/sre/v2/local/filter_utt_accd_dur.py | 36 +++ examples/sre/v2/local/generate_sre_aug.py | 57 +++++ examples/sre/v2/local/make_system_sad.py | 134 ++++++++++++ examples/sre/v2/local/prepare_data.sh | 92 ++++++++ examples/sre/v2/local/score.sh | 58 +++++ examples/sre/v2/local/utt2voice_duration.py | 36 +++ examples/sre/v2/path.sh | 5 + examples/sre/v2/run.sh | 123 +++++++++++ examples/sre/v2/tools | 1 + examples/sre/v2/wespeaker | 1 + tools/extract_embedding.sh | 10 +- tools/filter_scp.pl | 87 ++++++++ tools/fix_data_dir.sh | 217 +++++++++++++++++++ tools/make_raw_list.py | 35 ++- tools/make_shard_list.py | 108 ++++++++- tools/spk2utt_to_utt2spk.pl | 27 +++ wespeaker/bin/extract.py | 5 +- wespeaker/bin/train.py | 2 +- wespeaker/dataset/processor.py | 22 +- 23 files changed, 1240 insertions(+), 20 deletions(-) create mode 100644 examples/sre/v2/README.md create mode 100644 examples/sre/v2/conf/resnet.yaml create mode 100755 examples/sre/v2/local/extract_sre.sh create mode 100644 examples/sre/v2/local/filter_utt_accd_dur.py create mode 100644 examples/sre/v2/local/generate_sre_aug.py create mode 100644 examples/sre/v2/local/make_system_sad.py create mode 100755 examples/sre/v2/local/prepare_data.sh create mode 100755 examples/sre/v2/local/score.sh create mode 100644 examples/sre/v2/local/utt2voice_duration.py create mode 100755 examples/sre/v2/path.sh create mode 100755 examples/sre/v2/run.sh create mode 120000 examples/sre/v2/tools create mode 120000 examples/sre/v2/wespeaker create mode 100755 tools/filter_scp.pl create mode 100755 tools/fix_data_dir.sh create mode 100755 tools/spk2utt_to_utt2spk.pl diff --git a/README.md b/README.md index a773d1a6..3aed4e14 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ pip3 install wespeakerruntime ``` ## 🔥 News +* 2023.07.14: Support the [NIST SRE16 recipe](https://www.nist.gov/itl/iad/mig/speaker-recognition-evaluation-2016), see [#177](https://github.com/wenet-e2e/wespeaker/pull/177). * 2023.07.10: Support the [Self-Supervised Learning recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v3) on Voxceleb, including [DINO](https://openaccess.thecvf.com/content/ICCV2021/papers/Caron_Emerging_Properties_in_Self-Supervised_Vision_Transformers_ICCV_2021_paper.pdf), [MoCo](https://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf) and [SimCLR](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf), see [#180](https://github.com/wenet-e2e/wespeaker/pull/180). 
* 2023.06.30: Support the [SphereFace2](https://ieeexplore.ieee.org/abstract/document/10094954) loss function, with better performance and noise robustness in comparison with the ArcMargin Softmax, see [#173](https://github.com/wenet-e2e/wespeaker/pull/173).

@@ -44,14 +45,16 @@ pip3 install wespeakerruntime
 ## Recipes
 
 * [VoxCeleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb): Speaker Verification recipe on the [VoxCeleb dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/)
-    * 🔥 UPDATE 2023.07.10: We support self-supervised learning recipe on Voxceleb! Achiving **2.627%** (ECAPA_TDNN_GLOB_c1024) EER on vox1-O-clean test set without any labels.
-    * 🔥 UPDATE 2022.10.31: We support deep r-vector up to the 293-layer version! Achiving **0.447%/0.043** EER/mindcf on vox1-O-clean test set
+    * 🔥 UPDATE 2023.07.10: We support self-supervised learning recipe on Voxceleb! Achieving **2.627%** (ECAPA_TDNN_GLOB_c1024) EER on vox1-O-clean test set without any labels.
+    * 🔥 UPDATE 2022.10.31: We support deep r-vector up to the 293-layer version! Achieving **0.447%/0.043** EER/mindcf on vox1-O-clean test set
     * 🔥 UPDATE 2022.07.19: We apply the same setups as the CNCeleb recipe, and obtain SOTA performance considering the open-source systems
     - EER/minDCF on vox1-O-clean test set are **0.723%/0.069** (ResNet34) and **0.728%/0.099** (ECAPA_TDNN_GLOB_c1024), after LM fine-tuning and AS-Norm
 * [CNCeleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2): Speaker Verification recipe on the [CnCeleb dataset](http://cnceleb.org/)
     * 🔥 UPDATE 2022.10.31: 221-layer ResNet achieves **5.655%/0.330** EER/minDCF
     * 🔥 UPDATE 2022.07.12: We migrate the winner system of CNSRC 2022 [report](https://aishell-cnsrc.oss-cn-hangzhou.aliyuncs.com/T082.pdf) [slides](https://aishell-cnsrc.oss-cn-hangzhou.aliyuncs.com/T082-ZhengyangChen.pdf)
     - EER/minDCF reduction from 8.426%/0.487 to **6.492%/0.354** after large margin fine-tuning and AS-Norm
+* [NIST SRE16](https://github.com/wenet-e2e/wespeaker/tree/master/examples/sre/v2): Speaker Verification recipe for the [2016 NIST Speaker Recognition Evaluation Plan](https://www.nist.gov/itl/iad/mig/speaker-recognition-evaluation-2016). A similar recipe can be found in [Kaldi](https://github.com/kaldi-asr/kaldi/tree/master/egs/sre16).
+    * 🔥 UPDATE 2023.07.14: We support the NIST SRE16 recipe. After PLDA adaptation, we achieved 6.608%, 10.01%, and 2.974% EER on the Pooled, Tagalog, and Cantonese trials, respectively.
* [VoxConverse](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxconverse): Diarization recipe on the [VoxConverse dataset](https://www.robots.ox.ac.uk/~vgg/data/voxconverse/)
 
 ## Support List:
diff --git a/examples/sre/v2/README.md b/examples/sre/v2/README.md
new file mode 100644
index 00000000..8f97974a
--- /dev/null
+++ b/examples/sre/v2/README.md
@@ -0,0 +1,21 @@
+## Results for SRE16
+
+* Setup: fbank40, num_frms200, epoch150, Softmax, aug_prob0.6
+* Scoring: cosine & PLDA & PLDA Adaptation
+* Metric: EER(%)
+
+Without PLDA training data augmentation:
+| Model | Params | Backend | Pooled | Tagalog | Cantonese |
+|:------|:------:|:------------:|:------------:|:------------:|:------------:|
+| ResNet34-TSTP-emb256 | 6.63M | Cosine | 15.4 | 19.82 | 10.39 |
+| | | PLDA | 9.36 | 14.26 | 4.513 |
+| | | Adapt PLDA | 6.608 | 10.01 | 2.974 |
+
+With PLDA training data augmentation:
+| Model | Params | Backend | Pooled | Tagalog | Cantonese |
+|:------|:------:|:------------:|:------------:|:------------:|:------------:|
+| ResNet34-TSTP-emb256 | 6.63M | Cosine | 15.4 | 19.82 | 10.39 |
+| | | PLDA | 8.944 | 13.54 | 4.462 |
+| | | Adapt PLDA | 6.543 | 9.666 | 3.254 |
+
+* 🔥 UPDATE 2023.07.14: Support the [NIST SRE16 recipe](https://www.nist.gov/itl/iad/mig/speaker-recognition-evaluation-2016), see [#177](https://github.com/wenet-e2e/wespeaker/pull/177).
diff --git a/examples/sre/v2/conf/resnet.yaml b/examples/sre/v2/conf/resnet.yaml
new file mode 100644
index 00000000..b50f4ce2
--- /dev/null
+++ b/examples/sre/v2/conf/resnet.yaml
@@ -0,0 +1,81 @@
+### train configuration
+
+exp_dir: exp/ResNet34-TSTP-emb256-fbank40-num_frms200-aug0.6-spFalse-saFalse-Softmax-SGD-epoch150
+gpus: "[0,1]"
+num_avg: 10
+enable_amp: False # whether enable automatic mixed precision training
+
+seed: 42
+num_epochs: 150
+save_epoch_interval: 5 # save model every 5 epochs
+log_batch_interval: 100 # log every 100 batches
+
+dataloader_args:
+  batch_size: 256
+  num_workers: 16
+  pin_memory: False
+  prefetch_factor: 8
+  drop_last: True
+
+dataset_args:
+  # the number of samples traversed within one epoch; if set to 0,
+  # the utterance number of the dataset is used as the sample_num_per_epoch.
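+  # e.g. with batch_size 256, sample_num_per_epoch 780000 corresponds to roughly
+  # 780000 / 256 ≈ 3000 training steps per epoch, independent of the raw dataset size.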
+ sample_num_per_epoch: 780000 + shuffle: True + shuffle_args: + shuffle_size: 1500 + filter: True + filter_args: + min_num_frames: 100 + max_num_frames: 300 + resample_rate: 8000 + speed_perturb: False + num_frms: 200 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + fbank_args: + num_mel_bins: 40 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 +model_init: null +model_args: + feat_dim: 40 + embed_dim: 256 + pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False +projection_args: + project_type: "softmax" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.0 + final_margin: 0.2 + increase_start_epoch: 20 + fix_start_epoch: 40 + update_margin: True + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: SGD +optimizer_args: + momentum: 0.9 + nesterov: True + weight_decay: 0.0001 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.1 + final_lr: 0.00005 + warm_up_epoch: 6 + warm_from_zero: True diff --git a/examples/sre/v2/local/extract_sre.sh b/examples/sre/v2/local/extract_sre.sh new file mode 100755 index 00000000..c36ce5db --- /dev/null +++ b/examples/sre/v2/local/extract_sre.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com) +# 2023 Zhengyang Chen (chenzhengyang117@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exp_dir='' +model_path='' +nj=4 +gpus="[0,1]" +data_type="shard" # shard/raw/feat +data=data +reverb_data=data/rirs/lmdb +noise_data=data/musan/lmdb +aug_plda_data=0 + +. tools/parse_options.sh +set -e + +if [ $aug_plda_data = 0 ];then + sre_plda_data=sre +else + sre_plda_data=sre_aug +fi + +data_name_array=( + "${sre_plda_data}" + "sre16_major" + "sre16_eval_enroll" + "sre16_eval_test" +) +data_list_path_array=( + "${data}/${sre_plda_data}/${data_type}.list" + "${data}/sre16_major/${data_type}.list" + "${data}/sre16_eval_enroll/${data_type}.list" + "${data}/sre16_eval_test/${data_type}.list" +) +data_scp_path_array=( + "${data}/${sre_plda_data}/wav.scp" + "${data}/sre16_major/wav.scp" + "${data}/sre16_eval_enroll/wav.scp" + "${data}/sre16_eval_test/wav.scp" +) # to count the number of wavs +nj_array=($nj $nj $nj $nj) +batch_size_array=(1 1 1 1) # batch_size of test set must be 1 !!! 
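+# (batch_size must stay 1 here: extract.py keeps whole variable-length
+# utterances only when batch_size == 1, so larger batches cannot be collated)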
+num_workers_array=(1 1 1 1) +if [ $aug_plda_data = 0 ];then + aug_prob_array=(0.0 0.0 0.0 0.0) +else + aug_prob_array=(0.67 0.0 0.0 0.0) +fi +count=${#data_name_array[@]} + +for i in $(seq 0 $(($count - 1))); do + wavs_num=$(wc -l ${data_scp_path_array[$i]} | awk '{print $1}') + bash tools/extract_embedding.sh --exp_dir ${exp_dir} \ + --model_path $model_path \ + --data_type ${data_type} \ + --data_list ${data_list_path_array[$i]} \ + --wavs_num ${wavs_num} \ + --store_dir ${data_name_array[$i]} \ + --batch_size ${batch_size_array[$i]} \ + --num_workers ${num_workers_array[$i]} \ + --aug_prob ${aug_prob_array[$i]} \ + --reverb_data ${reverb_data} \ + --noise_data ${noise_data} \ + --nj ${nj_array[$i]} \ + --gpus $gpus +done + +wait + +echo "mean vector of enroll" +python tools/vector_mean.py \ + --spk2utt ${data}/sre16_eval_enroll/spk2utt \ + --xvector_scp $exp_dir/embeddings/sre16_eval_enroll/xvector.scp \ + --spk_xvector_ark $exp_dir/embeddings/sre16_eval_enroll/enroll_spk_xvector.ark + +mkdir -p ${exp_dir}/embeddings/eval +cat ${exp_dir}/embeddings/sre16_eval_enroll/enroll_spk_xvector.scp \ + ${exp_dir}/embeddings/sre16_eval_test/xvector.scp \ + > ${exp_dir}/embeddings/eval/xvector.scp + +echo "Embedding dir is (${exp_dir}/embeddings)." diff --git a/examples/sre/v2/local/filter_utt_accd_dur.py b/examples/sre/v2/local/filter_utt_accd_dur.py new file mode 100644 index 00000000..2541e396 --- /dev/null +++ b/examples/sre/v2/local/filter_utt_accd_dur.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 Zhengyang Chen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import fire + + +def main(wav_scp, utt2voice_dur, filter_wav_scp, dur_thres=5.0): + + utt2voice_dur_dict = {} + with open(utt2voice_dur, "r") as f: + for line in f: + utt, dur = line.strip().split() + utt2voice_dur_dict[utt] = float(dur) + + with open(wav_scp, "r") as f, open(filter_wav_scp, "w") as fw: + for line in f: + utt = line.strip().split()[0] + if utt in utt2voice_dur_dict: + if utt2voice_dur_dict[utt] > dur_thres: + fw.write(line) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sre/v2/local/generate_sre_aug.py b/examples/sre/v2/local/generate_sre_aug.py new file mode 100644 index 00000000..29bda825 --- /dev/null +++ b/examples/sre/v2/local/generate_sre_aug.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 Zhengyang Chen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
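+
+# This script writes (aug_copy_num + 1) copies of every utterance (suffixes
+# _copy-0 ... _copy-N) into wav.scp, utt2spk and vad, so that on-the-fly
+# augmentation later produces several differently-augmented versions of each
+# PLDA training utterance.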
+
+
+import os
+import fire
+
+
+def main(ori_dir, aug_dir, aug_copy_num=2):
+
+    if not os.path.exists(aug_dir):
+        os.makedirs(aug_dir)
+
+    read_wav_scp = os.path.join(ori_dir, 'wav.scp')
+    aug_wav_scp = os.path.join(aug_dir, 'wav.scp')
+    read_utt2spk = os.path.join(ori_dir, 'utt2spk')
+    aug_utt2spk = os.path.join(aug_dir, 'utt2spk')
+    read_vad = os.path.join(ori_dir, 'vad')
+    store_vad = os.path.join(aug_dir, 'vad')
+
+    with open(read_wav_scp, 'r') as f, open(aug_wav_scp, 'w') as wf:
+        for line in f:
+            line = line.strip().split()
+            utt, other_info = line[0], ' '.join(line[1:])
+            for i in range(aug_copy_num + 1):
+                wf.write(utt + '_copy-' + str(i) + ' ' + other_info + '\n')
+
+    with open(read_utt2spk, 'r') as f, open(aug_utt2spk, 'w') as wf:
+        for line in f:
+            line = line.strip().split()
+            utt, spk = line[0], line[1]
+            for i in range(aug_copy_num + 1):
+                wf.write(utt + '_copy-' + str(i) + ' ' + spk + '\n')
+
+    with open(read_vad, 'r') as f, open(store_vad, 'w') as wf:
+        for line in f:
+            line = line.strip().split()
+            seg, utt, vad = line[0], line[1], ' '.join(line[2:])
+            for i in range(aug_copy_num + 1):
+                new_seg = seg + '_copy-' + str(i)
+                new_utt = utt + '_copy-' + str(i)
+                wf.write(new_seg + ' ' + new_utt + ' ' + vad + '\n')
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/examples/sre/v2/local/make_system_sad.py b/examples/sre/v2/local/make_system_sad.py
new file mode 100644
index 00000000..84d7d7df
--- /dev/null
+++ b/examples/sre/v2/local/make_system_sad.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2022 Xu Xiang
+#               2023 Zhengyang Chen
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["OPENBLAS_NUM_THREADS"] = "1"
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+
+import sys
+import io
+import functools
+import concurrent.futures
+import argparse
+import importlib
+import torchaudio
+import subprocess
+
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--repo-path', required=True,
+                        help='VAD model repo path')
+    parser.add_argument('--scp', required=True, help='wav scp')
+    parser.add_argument('--min-duration', required=True,
+                        type=float, help='min duration')
+    args = parser.parse_args()
+
+    return args
+
+
+@functools.lru_cache(maxsize=1)
+def load_wav(
+    wav_rxfilename,
+):
+    """ This function reads an audio file and returns the data as a pytorch tensor.
+    "lru_cache" holds recently loaded audio so that it can be called
+    many times on the same audio file.
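+    Supports Kaldi-style extended filenames: a trailing '|' runs the piped
+    command and reads its stdout, '-' reads from stdin, and anything else
+    is treated as a regular audio file path.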
+ OPTIMIZE: controls lru_cache size for random access, + considering memory size + """ + if wav_rxfilename.endswith('|'): + # input piped command + p = subprocess.Popen(wav_rxfilename[:-1], shell=True, + stdout=subprocess.PIPE) + data, samplerate = torchaudio.load(io.BytesIO(p.stdout.read())) + elif wav_rxfilename == '-': + # stdin + data, samplerate = torchaudio.load(sys.stdin) + else: + # normal wav file + data, samplerate = torchaudio.load(wav_rxfilename) + return data.squeeze(0), samplerate + + +def read_scp(scp): + utt_wav_pair = [] + for line in open(scp, 'r'): + segs = line.strip().split() + if len(segs) > 2: + utt, wav = segs[0], ' '.join(segs[1:]) + else: + utt, wav = segs[0], segs[1] + utt_wav_pair.append((utt, wav)) + + return utt_wav_pair + + +def silero_vad(utt_wav_pair, repo_path, min_duration, + sampling_rate=8000, threshold=0.25): + + def module_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + utils_vad = module_from_file("utils_vad", + os.path.join(repo_path, "utils_vad.py")) + model = utils_vad.init_jit_model( + os.path.join(repo_path, 'files/silero_vad.jit')) + + utt, wav = utt_wav_pair + + wav, sr = load_wav(wav) + assert sr == sampling_rate + speech_timestamps = utils_vad.get_speech_timestamps( + wav, model, sampling_rate=sampling_rate, + threshold=threshold) + + vad_result = "" + for item in speech_timestamps: + begin = item['start'] / sampling_rate + end = item['end'] / sampling_rate + if end - begin >= min_duration: + vad_result += "{}-{:08d}-{:08d} {} {:.3f} {:.3f}\n".format( + utt, int(begin * 1000), int(end * 1000), utt, begin, end) + + return vad_result + + +def main(): + args = get_args() + + vad = functools.partial(silero_vad, + repo_path=args.repo_path, + min_duration=args.min_duration) + utt_wav_pair_list = read_scp(args.scp) + + with concurrent.futures.ProcessPoolExecutor() as executor: + print(''.join(executor.map(vad, utt_wav_pair_list)), end='') + + +if __name__ == '__main__': + torch.set_num_threads(1) + + main() diff --git a/examples/sre/v2/local/prepare_data.sh b/examples/sre/v2/local/prepare_data.sh new file mode 100755 index 00000000..5770152b --- /dev/null +++ b/examples/sre/v2/local/prepare_data.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# Copyright (c) 2023 Zhengyang Chen (chenzhengyang117@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +stage=-1 +stop_stage=-1 +sre_data_dir= +data=data + +. 
tools/parse_options.sh || exit 1
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  mkdir -p external_tools
+  # Download voice activity detection model pretrained by Silero Team
+  wget -c https://github.com/snakers4/silero-vad/archive/refs/tags/v4.0.zip -O external_tools/silero-vad-v4.0.zip
+  unzip -o external_tools/silero-vad-v4.0.zip -d external_tools
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  # The meta data for SRE16 should be pre-prepared using the Kaldi recipe:
+  # https://github.com/kaldi-asr/kaldi/tree/master/egs/sre16/v2
+  for dset in swbd_sre sre sre16_major sre16_eval_enroll sre16_eval_test; do
+    mkdir -p ${data}/${dset}
+    cp ${sre_data_dir}/${dset}/wav.scp ${data}/${dset}/wav.scp
+    [ -f ${sre_data_dir}/${dset}/utt2spk ] && cp ${sre_data_dir}/${dset}/utt2spk ${data}/${dset}/utt2spk
+    [ -f ${sre_data_dir}/${dset}/spk2utt ] && cp ${sre_data_dir}/${dset}/spk2utt ${data}/${dset}/spk2utt
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Get vad segmentation for dataset."
+  # Set VAD min duration
+  min_duration=0.255
+  for dset in swbd_sre sre16_major sre16_eval_enroll sre16_eval_test; do
+    python3 local/make_system_sad.py \
+      --repo-path external_tools/silero-vad-4.0 \
+      --scp ${data}/${dset}/wav.scp \
+      --min-duration $min_duration > ${data}/${dset}/vad
+  done
+  tools/filter_scp.pl -f 2 ${data}/sre/wav.scp ${data}/swbd_sre/vad > ${data}/sre/vad
+
+  # For PLDA training, it is better to augment the training data
+  python3 local/generate_sre_aug.py --ori_dir ${data}/sre \
+    --aug_dir ${data}/sre_aug \
+    --aug_copy_num 2
+  tools/utt2spk_to_spk2utt.pl ${data}/sre_aug/utt2spk > ${data}/sre_aug/spk2utt
+
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  for dset in swbd_sre; do
+    python3 local/utt2voice_duration.py \
+      --vad_file ${data}/${dset}/vad \
+      --utt2voice_dur ${data}/${dset}/utt2voice_dur
+  done
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  # Following the Kaldi recipe: https://github.com/kaldi-asr/kaldi/blob/71f38e62cad01c3078555bfe78d0f3a527422d75/egs/sre16/v2/run.sh#L189
+  # we filter out the utterances with less than 5s of voiced audio
+  for dset in swbd_sre; do
+    python3 local/filter_utt_accd_dur.py \
+      --wav_scp ${data}/${dset}/wav.scp \
+      --utt2voice_dur ${data}/${dset}/utt2voice_dur \
+      --filter_wav_scp ${data}/${dset}/filter_wav.scp \
+      --dur_thres 5.0
+    mv ${data}/${dset}/wav.scp ${data}/${dset}/wav.scp.bak
+    mv ${data}/${dset}/filter_wav.scp ${data}/${dset}/wav.scp
+  done
+
+  # Similarly, following the Kaldi recipe, we throw out speakers with
+  # too few utterances (the awk filter keeps spk2utt lines with NF > 2,
+  # i.e. speakers with at least 2 utterances).
+  for dset in swbd_sre; do
+    tools/fix_data_dir.sh ${data}/${dset}
+    cp ${data}/${dset}/spk2utt ${data}/${dset}/spk2utt.bak
+    awk '{if(NF>2){print $0}}' ${data}/${dset}/spk2utt.bak > ${data}/${dset}/spk2utt
+    tools/spk2utt_to_utt2spk.pl ${data}/${dset}/spk2utt > ${data}/${dset}/utt2spk
+    tools/fix_data_dir.sh ${data}/${dset}
+  done
+fi
diff --git a/examples/sre/v2/local/score.sh b/examples/sre/v2/local/score.sh
new file mode 100755
index 00000000..dbe868bc
--- /dev/null
+++ b/examples/sre/v2/local/score.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# Copyright (c) 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn)
+#               2023 Zhengyang Chen (chenzhengyang117@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exp_dir= +trials="trials trials_tgl trials_yue" +data=data + +stage=-1 +stop_stage=-1 + +. tools/parse_options.sh +. path.sh + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "apply cosine scoring ..." + mkdir -p ${exp_dir}/scores + trials_dir=${data}/trials + for x in $trials; do + echo $x + python wespeaker/bin/score.py \ + --exp_dir ${exp_dir} \ + --eval_scp_path ${exp_dir}/embeddings/eval/xvector.scp \ + --cal_mean True \ + --cal_mean_dir ${exp_dir}/embeddings/sre16_major \ + ${trials_dir}/${x} + done +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "compute metrics (EER/minDCF) ..." + scores_dir=${exp_dir}/scores + for x in $trials; do + python wespeaker/bin/compute_metrics.py \ + --p_target 0.01 \ + --c_fa 1 \ + --c_miss 1 \ + ${scores_dir}/${x}.score \ + 2>&1 | tee -a ${scores_dir}/sre16_cos_result + + echo "compute DET curve ..." + python wespeaker/bin/compute_det.py \ + ${scores_dir}/${x}.score + done +fi diff --git a/examples/sre/v2/local/utt2voice_duration.py b/examples/sre/v2/local/utt2voice_duration.py new file mode 100644 index 00000000..2dd7d1c1 --- /dev/null +++ b/examples/sre/v2/local/utt2voice_duration.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 Zhengyang Chen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fire +from collections import OrderedDict + + +def main(vad_file, utt2voice_dur): + utt2voice_dur_dict = OrderedDict() + + with open(vad_file, 'r') as f: + for line in f.readlines(): + segs = line.strip().split() + utt, start, end = segs[-3], float(segs[-2]), float(segs[-1]) + if utt not in utt2voice_dur_dict: + utt2voice_dur_dict[utt] = 0.0 + utt2voice_dur_dict[utt] += end - start + + with open(utt2voice_dur, 'w') as f: + for utt, duration in utt2voice_dur_dict.items(): + f.write('{} {}\n'.format(utt, duration)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sre/v2/path.sh b/examples/sre/v2/path.sh new file mode 100755 index 00000000..b90a5154 --- /dev/null +++ b/examples/sre/v2/path.sh @@ -0,0 +1,5 @@ +export PATH=$PWD:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH diff --git a/examples/sre/v2/run.sh b/examples/sre/v2/run.sh new file mode 100755 index 00000000..810b48b2 --- /dev/null +++ b/examples/sre/v2/run.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +# Copyright 2022 Hongji Wang (jijijiang77@gmail.com) +# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) +# 2023 Zhengyang Chen (chenzhengyang117@gmail.com) + +. 
./path.sh || exit 1 + +stage=-1 +stop_stage=-1 + +# the sre data should be prepared in kaldi format and stored in the following directory +# only wav.scp, utt2spk and spk2utt files are needed +sre_data_dir=sre_data_dir +data=data +data_type="shard" # shard/raw +# whether augment the PLDA data +aug_plda_data=0 + +config=conf/resnet.yaml +exp_dir=exp/ResNet34-TSTP-emb256-fbank40-num_frms200-aug0.6-spFalse-saFalse-Softmax-SGD-epoch150 +gpus="[0,1]" +num_avg=10 +checkpoint= + +trials="trials trials_tgl trials_yue" + +. tools/parse_options.sh || exit 1 + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Prepare datasets ..." + ./local/prepare_data.sh --stage 2 --stop_stage 5 --sre_data_dir ${sre_data_dir} --data ${data} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Convert train data to ${data_type}..." + for dset in swbd_sre; do + python tools/make_shard_list.py --num_utts_per_shard 1000 \ + --num_threads 16 \ + --prefix shards \ + --shuffle \ + ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ + ${data}/$dset/shards ${data}/$dset/shard.list \ + ${data}/$dset/vad + done + + echo "Convert data for PLDA backend training and evaluation to raw format..." + if [ $aug_plda_data = 0 ];then + sre_plda_data=sre + else + sre_plda_data=sre_aug + fi + for dset in ${sre_plda_data} sre16_major sre16_eval_enroll sre16_eval_test; do + python tools/make_raw_list.py ${data}/$dset/wav.scp \ + ${data}/$dset/utt2spk ${data}/$dset/raw.list \ + ${data}/$dset/vad + + done + # Convert all musan data to LMDB + python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb + # Convert all rirs data to LMDB + python tools/make_lmdb.py ${data}/rirs/wav.scp ${data}/rirs/lmdb +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Start training ..." + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wespeaker/bin/train.py --config $config \ + --exp_dir ${exp_dir} \ + --gpus $gpus \ + --num_avg ${num_avg} \ + --data_type "${data_type}" \ + --train_data ${data}/swbd_sre/${data_type}.list \ + --train_label ${data}/swbd_sre/utt2spk \ + --reverb_data ${data}/rirs/lmdb \ + --noise_data ${data}/musan/lmdb \ + ${checkpoint:+--checkpoint $checkpoint} +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "Do model average ..." + avg_model=$exp_dir/models/avg_model.pt + python wespeaker/bin/average_model.py \ + --dst_model $avg_model \ + --src_path $exp_dir/models \ + --num ${num_avg} + + model_path=$avg_model + if [[ $config == *repvgg*.yaml ]]; then + echo "convert repvgg model ..." + python wespeaker/models/convert_repvgg.py \ + --config $exp_dir/config.yaml \ + --load $avg_model \ + --save $exp_dir/models/convert_model.pt + model_path=$exp_dir/models/convert_model.pt + fi + + echo "Extract embeddings ..." + local/extract_sre.sh \ + --exp_dir $exp_dir --model_path $model_path \ + --nj 32 --gpus $gpus --data_type raw --data ${data} \ + --reverb_data ${data}/rirs/lmdb \ + --noise_data ${data}/musan/lmdb \ + --aug_plda_data ${aug_plda_data} +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "Score ..." + local/score.sh \ + --stage 1 --stop-stage 2 \ + --data ${data} \ + --exp_dir $exp_dir \ + --trials "$trials" +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "Export the best model ..." 
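+    # export_jit.py traces the averaged checkpoint into a TorchScript file
+    # (models/final.zip) that can be loaded for deployment without the training code.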
+    python wespeaker/bin/export_jit.py \
+        --config $exp_dir/config.yaml \
+        --checkpoint $exp_dir/models/avg_model.pt \
+        --output_file $exp_dir/models/final.zip
+fi
diff --git a/examples/sre/v2/tools b/examples/sre/v2/tools
new file mode 120000
index 00000000..c92f4172
--- /dev/null
+++ b/examples/sre/v2/tools
@@ -0,0 +1 @@
+../../../tools
\ No newline at end of file
diff --git a/examples/sre/v2/wespeaker b/examples/sre/v2/wespeaker
new file mode 120000
index 00000000..900c560b
--- /dev/null
+++ b/examples/sre/v2/wespeaker
@@ -0,0 +1 @@
+../../../wespeaker
\ No newline at end of file
diff --git a/tools/extract_embedding.sh b/tools/extract_embedding.sh
index 6e962d1a..a12704c5 100755
--- a/tools/extract_embedding.sh
+++ b/tools/extract_embedding.sh
@@ -24,6 +24,9 @@ store_dir=
 batch_size=1
 num_workers=1
 nj=4
+reverb_data=data/rirs/lmdb
+noise_data=data/musan/lmdb
+aug_prob=0.0
 gpus="[0,1]"
 
 . tools/parse_options.sh
@@ -53,6 +56,9 @@ for suffix in $(seq 0 $(($nj - 1))); do
       --embed_ark ${embed_ark} \
       --batch-size ${batch_size} \
       --num-workers ${num_workers} \
+      --reverb_data ${reverb_data} \
+      --noise_data ${noise_data} \
+      --aug-prob ${aug_prob} \
       >${log_dir}/split_${suffix}.log 2>&1 &
 done
 
@@ -61,7 +67,7 @@ wait
 cat ${embed_dir}/xvector_*.scp >${embed_dir}/xvector.scp
 embed_num=$(wc -l ${embed_dir}/xvector.scp | awk '{print $1}')
 if [ $embed_num -eq $wavs_num ]; then
-  echo "Success" | tee ${embed_dir}/extract.result
+  echo "Successfully extracted embeddings for ${store_dir}" | tee ${embed_dir}/extract.result
 else
-  echo "Fail" | tee ${embed_dir}/extract.result
+  echo "Failed to extract embeddings for ${store_dir}" | tee ${embed_dir}/extract.result
 fi
diff --git a/tools/filter_scp.pl b/tools/filter_scp.pl
new file mode 100755
index 00000000..b76d37f4
--- /dev/null
+++ b/tools/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+#                     Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+  $shifted=0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted=1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted=1
+  }
+} while ($shifted);
+
+if(@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+      "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
+      "only the lines that were *not* in id_list.\n" .
+ "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/tools/fix_data_dir.sh b/tools/fix_data_dir.sh new file mode 100755 index 00000000..967a1f42 --- /dev/null +++ b/tools/fix_data_dir.sh @@ -0,0 +1,217 @@ +#!/usr/bin/env bash + +# Directly copied from https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/utils/fix_data_dir.sh +# +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. tools/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: local/data/fix_data_dir.sh " + echo "e.g.: local/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + local/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." 
+    fi
+  fi
+  rm $file_to_filter.tmp
+}
+
+function filter_recordings {
+  # We call this once before the stage when we filter on utterance-id, and once
+  # after.
+
+  if [ -f $data/segments ]; then
+    # We have a segments file -> we need to filter this and the file wav.scp, and
+    # reco2file_and_utt, if it exists, to make sure they have the same list of
+    # recording-ids.
+
+    if [ ! -f $data/wav.scp ]; then
+      echo "$0: $data/segments exists but not $data/wav.scp"
+      exit 1;
+    fi
+    awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
+    n1=$(cat $tmpdir/recordings | wc -l)
+    [ ! -s $tmpdir/recordings ] && \
+      echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
+    tools/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
+    mv $tmpdir/recordings.tmp $tmpdir/recordings
+
+
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+    filter_file $tmpdir/recordings $data/segments
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+    rm $data/segments.tmp
+
+    filter_file $tmpdir/recordings $data/wav.scp
+    [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
+    [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
+    true
+  fi
+}
+
+function filter_speakers {
+  # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
+  tools/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  for s in cmvn.scp spk2gender; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $f $tmpdir/speakers
+    fi
+  done
+
+  filter_file $tmpdir/speakers $data/spk2utt
+  tools/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+
+  for s in cmvn.scp spk2gender $spk_extra_files; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $tmpdir/speakers $f
+    fi
+  done
+}
+
+function filter_utts {
+  cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+
+  ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order (fix this yourself)" && exit 1;
+
+  ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order when sorted first on speaker-id " && \
+    echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+
+  ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \
+    echo "spk2utt is not in sorted order (fix this yourself)" && exit 1;
+
+  if [ -f $data/utt2uniq ]; then
+    ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \
+      echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1;
+  fi
+
+  maybe_wav=
+  maybe_reco2dur=
+  [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist.
+  [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts
+
+  maybe_utt2dur=
+  if [ -f $data/utt2dur ]; then
+    cat $data/utt2dur | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1
+    maybe_utt2dur=utt2dur.ok
+  fi
+
+  maybe_utt2num_frames=
+  if [ -f $data/utt2num_frames ]; then
+    cat $data/utt2num_frames | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1
+    maybe_utt2num_frames=utt2num_frames.ok
+  fi
+
+  for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
+    if [ -f $data/$x ]; then
+      tools/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp
+      mv $tmpdir/utts.tmp $tmpdir/utts
+    fi
+  done
+  rm $data/utt2dur.ok 2>/dev/null || true
+  rm $data/utt2num_frames.ok 2>/dev/null || true
+
+  [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \
+    rm $tmpdir/utts && exit 1;
+
+
+  if [ -f $data/utt2spk ]; then
+    new_nutts=$(cat $tmpdir/utts | wc -l)
+    old_nutts=$(cat $data/utt2spk | wc -l)
+    if [ $new_nutts -ne $old_nutts ]; then
+      echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
+    else
+      echo "fix_data_dir.sh: kept all $old_nutts utterances."
+    fi
+  fi
+
+  for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
+    if [ -f $data/$x ]; then
+      cp $data/$x $data/.backup/$x
+      if ! cmp -s $data/$x <( tools/filter_scp.pl $tmpdir/utts $data/$x ) ; then
+        tools/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x
+      fi
+    fi
+  done
+
+}
+
+filter_recordings
+filter_speakers
+filter_utts
+filter_speakers
+filter_recordings
+
+tools/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+echo "fix_data_dir.sh: old files are kept in $data/.backup"
diff --git a/tools/make_raw_list.py b/tools/make_raw_list.py
index 4c093b0f..93a52a2e 100644
--- a/tools/make_raw_list.py
+++ b/tools/make_raw_list.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
+#               2023 Zhengyang Chen(chenzhengyang117@gmail.com)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
 import argparse
 import logging
 import json
+import os
 
 
 def get_args():
@@ -22,6 +24,7 @@ def get_args():
     parser.add_argument('wav_file', help='wav file')
     parser.add_argument('utt2spk_file', help='utt2spk file')
     parser.add_argument('raw_list', help='output raw list file')
+    parser.add_argument('vad_file', nargs='?', default='non_exist', help='vad file')
     args = parser.parse_args()
 
     return args
@@ -36,8 +39,19 @@ def main():
         for line in fin:
             arr = line.strip().split()
             key = arr[0]  # os.path.splitext(arr[0])[0]
-            assert len(arr) == 2
-            wav_table[key] = arr[1]
+            wav_table[key] = ' '.join(arr[1:])
+
+    if os.path.exists(args.vad_file):
+        vad_dict = {}
+        with open(args.vad_file, 'r', encoding='utf8') as fin:
+            for line in fin:
+                arr = line.strip().split()
+                utt, start, end = arr[-3], arr[-2], arr[-1]
+                if utt not in vad_dict:
+                    vad_dict[utt] = []
+                vad_dict[utt].append((start, end))
+    else:
+        vad_dict = None
 
     data = []
     with open(args.utt2spk_file, 'r', encoding='utf8') as fin:
@@ -47,11 +61,22 @@ def main():
             spk = arr[1]
             assert key in wav_table
             wav = wav_table[key]
-            data.append((key, spk, wav))
+            if vad_dict is None:
+                data.append((key, spk, wav))
+            else:
+                if key not in vad_dict:
+                    continue
+                vad = vad_dict[key]
+                data.append((key, spk, wav, vad))
 
     with open(args.raw_list, 'w', encoding='utf8') as fout:
-        for key, spk, wav in data:
-            line = dict(key=key, spk=spk, wav=wav)
+        for utt_info in data:
+            if len(utt_info) == 4:
+                key, spk, wav, vad = utt_info
+                line = dict(key=key, spk=spk, wav=wav, vad=vad)
+            else:
+                key, spk, wav = utt_info
+                line = dict(key=key, spk=spk, wav=wav)
             json_line = json.dumps(line, ensure_ascii=False)
             fout.write(json_line + '\n')
diff --git a/tools/make_shard_list.py b/tools/make_shard_list.py
index 6524822f..93c28ea8 100644
--- a/tools/make_shard_list.py
+++ b/tools/make_shard_list.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+#               2023 Shanghai Jiaotong University (authors: Zhengyang Chen)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,22 +21,97 @@
 import tarfile
 import time
 import multiprocessing
+import subprocess
+from scipy.io import wavfile
+import numpy as np
+import struct
 
 AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
 
+def write_wav_to_bytesio(audio_data, sample_rate):
+    audio_data = audio_data.astype(np.int16)
+
+    with io.BytesIO() as wav_stream:
+        # WAV header values
+        num_channels = 1  # Mono audio
+        bytes_per_sample = 2  # Assuming 16-bit audio
+
+        # Write WAV header
+        wav_stream.write(b'RIFF')
+        wav_stream.write(b'\x00\x00\x00\x00')  # Placeholder for file size
+        wav_stream.write(b'WAVE')
+
+        # Write format chunk (standard PCM fmt chunk; reconstructed here, the
+        # original lines were lost in transmission)
+        wav_stream.write(b'fmt ')
+        wav_stream.write(struct.pack('<I', 16))  # fmt chunk size for PCM
+        wav_stream.write(struct.pack('<H', 1))  # audio format: PCM
+        wav_stream.write(struct.pack('<H', num_channels))
+        wav_stream.write(struct.pack('<I', sample_rate))
+        wav_stream.write(struct.pack('<I', sample_rate * num_channels * bytes_per_sample))  # byte rate
+        wav_stream.write(struct.pack('<H', num_channels * bytes_per_sample))  # block align
+        wav_stream.write(struct.pack('<H', bytes_per_sample * 8))  # bits per sample
+
+        # Write data chunk and patch the RIFF size placeholder
+        wav_stream.write(b'data')
+        wav_stream.write(struct.pack('<I', audio_data.nbytes))
+        wav_stream.write(audio_data.tobytes())
+        file_size = wav_stream.tell()
+        wav_stream.seek(4)
+        wav_stream.write(struct.pack('<I', file_size - 8))
+
+        return wav_stream.getvalue()
diff --git a/tools/spk2utt_to_utt2spk.pl b/tools/spk2utt_to_utt2spk.pl
new file mode 100755
--- /dev/null
+++ b/tools/spk2utt_to_utt2spk.pl
@@ -0,0 +1,27 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts a spk2utt file to an utt2spk file.
+
+while(<>){
+  @A = split(" ", $_);
+  @A > 1 || die "Invalid line in spk2utt file: $_";
+  $s = shift @A;
+  foreach $u ( @A ) {
+    print "$u $s\n";
+  }
+}
+
+
diff --git a/wespeaker/bin/extract.py b/wespeaker/bin/extract.py
index 4f493427..89d47c4c 100644
--- a/wespeaker/bin/extract.py
+++ b/wespeaker/bin/extract.py
@@ -54,6 +54,7 @@ def extract(config='conf/config.yaml', **kwargs):
         test_conf['mfcc_args']['dither'] = 0.0
     test_conf['spec_aug'] = False
     test_conf['shuffle'] = False
+    test_conf['aug_prob'] = configs.get('aug_prob', 0.0)
     test_conf['filter'] = False
 
     dataset = Dataset(configs['data_type'],
@@ -61,8 +62,8 @@ def extract(config='conf/config.yaml', **kwargs):
                       test_conf,
                       spk2id_dict={},
                       whole_utt=(batch_size == 1),
-                      reverb_lmdb_file=None,
-                      noise_lmdb_file=None,
+                      reverb_lmdb_file=configs.get('reverb_data', None),
+                      noise_lmdb_file=configs.get('noise_data', None),
                       repeat_dataset=False)
     dataloader = DataLoader(dataset,
                             shuffle=False,
diff --git a/wespeaker/bin/train.py b/wespeaker/bin/train.py
index 6b9b11f3..ca618301 100644
--- a/wespeaker/bin/train.py
+++ b/wespeaker/bin/train.py
@@ -100,7 +100,7 @@ def train(config='conf/config.yaml', **kwargs):
     if rank == 0:
         logger.info("<== Dataloaders ==>")
         logger.info("train dataloaders created")
-        logger.info('loader size: {}'.format(epoch_iter))
+        logger.info('epoch iteration number: {}'.format(epoch_iter))
 
     # model
     logger.info("<== Model ==>")
diff --git a/wespeaker/dataset/processor.py b/wespeaker/dataset/processor.py
index 9aa41564..676c02ab 100644
--- a/wespeaker/dataset/processor.py
+++ b/wespeaker/dataset/processor.py
@@ -125,6 +125,24 @@ def parse_raw(data):
         Returns:
             Iterable[{key, wav, spk, sample_rate}]
     """
+    def read_audio(wav):
+        if wav.endswith('|'):
+            p = Popen(wav[:-1], shell=True, stdout=PIPE)
+            data = p.stdout.read()
+            waveform, sample_rate = torchaudio.load(io.BytesIO(data))
+        else:
+            waveform, sample_rate = torchaudio.load(wav)
+        return waveform, sample_rate
+
+    def apply_vad(waveform, sample_rate, vad):
+        voice_part_list = []
+        for start, end in vad:
+            start, end = float(start), float(end)
+            start, end = int(start * sample_rate), int(end * sample_rate)
+            voice_part_list.append(waveform[:, start:end])
+        waveform = torch.cat(voice_part_list, dim=1)
+        return waveform, sample_rate
+
     for sample in data:
         assert 'src' in sample
         json_line = sample['src']
@@ -136,7 +154,9 @@ def parse_raw(data):
             wav_file = obj['wav']
             spk = obj['spk']
             try:
-                waveform, sample_rate = torchaudio.load(wav_file)
+                waveform, sample_rate = read_audio(wav_file)
+                if 'vad' in obj:
+                    waveform, sample_rate = apply_vad(waveform, sample_rate, obj['vad'])
                 example = dict(key=key,
                                spk=spk,
                                wav=waveform,