[plda] Update the PLDA code (#186)
* update the plda codes

* update the plda code to run.sh

* update readme

* reformat code

* fix lint error

* fix lint error

* update vox plda recipe

* update vox plda recipe

* Delete __init__.py

---------

Co-authored-by: wangshuai <[email protected]>
wsstriving and wangshuai authored Jul 19, 2023
1 parent 2d52982 commit 0473c9c
Showing 25 changed files with 715 additions and 190 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@ pip3 install wespeakerruntime
```

## 🔥 News
* 2023.07.18: Support the kaldi-compatible PLDA and unsupervised adaptation, see [#186](https://github.com/wenet-e2e/wespeaker/pull/186).
* 2023.07.14: Support the [NIST SRE16 recipe](https://www.nist.gov/itl/iad/mig/speaker-recognition-evaluation-2016), see [#177](https://github.com/wenet-e2e/wespeaker/pull/177).
* 2023.07.10: Support the [Self-Supervised Learning recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v3) on Voxceleb, including [DINO](https://openaccess.thecvf.com/content/ICCV2021/papers/Caron_Emerging_Properties_in_Self-Supervised_Vision_Transformers_ICCV_2021_paper.pdf), [MoCo](https://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf) and [SimCLR](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf), see [#180](https://github.com/wenet-e2e/wespeaker/pull/180).

31 changes: 26 additions & 5 deletions ROADMAP.md
@@ -1,9 +1,29 @@
# Wespeaker Roadmap

## Version 2.0 (Time: 2023.09)

This is the roadmap for wespeaker version 2.0.

- [ ] SSL support
  - [ ] Algorithms
    - [x] DINO
    - [x] MOCO
    - [x] SimCLR
    - [ ] Iterative pseudo-label prediction and supervised finetuning
  - [ ] Recipes
    - [x] VoxCeleb
    - [ ] WenetSpeech

- [ ] Recipes
  - [ ] 3D-speaker
  - [ ] NIST SRE
    - [x] SRE16
    - [ ] SRE18

## Version 1.0 (Time: 2022.09)

This is the roadmap for wespeaker version 1.0.


- [x] Standard dataset support
- [x] VoxCeleb
- [x] CnCeleb
@@ -18,13 +38,14 @@ This is the roadmap for wespeaker version 1.0.
- [x] PLDA
- [x] UIO for effective industrial-scale dataset processing
- [x] Online data augmentation
  - Noise && RIR
  - Speed Perturb
  - Specaug
- [x] ONNX support
- [x] Triton Server support (GPU)
- [ ] ~~
  - Training or finetuning big models such as WavLM might be too costly at the current stage
- [x] Basic Speaker Diarization Recipe
- Embedding based (more related with our speaker embedding learner toolkit)
- [x] Interactive Demo
25 changes: 12 additions & 13 deletions examples/sre/v2/README.md
@@ -4,18 +4,17 @@
* Scoring: cosine & PLDA & PLDA Adaptation
* Metric: EER(%)

Without PLDA training data augmentation:
| Model | Params | Backend | Pooled | Tagalog | Cantonese |
|:------|:------:|:------------:|:------------:|:------------:|:------------:|
| ResNet34-TSTP-emb256 | 6.63M | Cosine | 15.4 | 19.82 | 10.39 |
| | | PLDA | 9.36 | 14.26 | 4.513 |
| | | Adapt PLDA | 6.608 | 10.01 | 2.974 |
| Model | Params | Backend | Pooled | Tagalog | Cantonese |
|:---------------------|:------:|:----------:|:------:|:-------:|:---------:|
| ResNet34-TSTP-emb256 | 6.63M | Cosine | 15.4 | 19.82 | 10.39 |
| | | PLDA | 11.689 | 16.961 | 6.239 |
| | | Adapt PLDA | 5.788 | 8.974 | 2.674 |

With PLDA training data augmentation:
| Model | Params | Backend | Pooled | Tagalog | Cantonese |
|:------|:------:|:------------:|:------------:|:------------:|:------------:|
| ResNet34-TSTP-emb256 | 6.63M | Cosine | 15.4 | 19.82 | 10.39 |
| | | PLDA | 8.944 | 13.54 | 4.462 |
| | | Adapt PLDA | 6.543 | 9.666 | 3.254 |
The current PLDA implementation is fully compatible with the Kaldi version. Note that
the results without adaptation could likely be improved with parameter tuning and an extra LDA step, as shown in the Kaldi
recipe. We did not do this because we focus on the adapted results, which are good enough under the current setup.

* 🔥 UPDATE 2023.07.18: Support kaldi-compatible two-covariance PLDA and unsupervised domain adaptation.
* 🔥 UPDATE 2023.07.14: Support the [NIST SRE16 recipe](https://www.nist.gov/itl/iad/mig/speaker-recognition-evaluation-2016), see [#177](https://github.com/wenet-e2e/wespeaker/pull/177).
87 changes: 87 additions & 0 deletions examples/sre/v2/local/score_plda.sh
@@ -0,0 +1,87 @@
#!/bin/bash

# Copyright (c) 2023 Shuai Wang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
exp_dir=
trials="trials trials_tgl trials_yue"
data=data
aug_plda_data=0

stage=-1
stop_stage=-1

. tools/parse_options.sh
. path.sh

if [ $aug_plda_data = 0 ];then
sre_plda_data=sre
else
sre_plda_data=sre_aug
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "train the plda model ..."
python wespeaker/bin/train_plda.py \
--exp_dir ${exp_dir} \
--scp_path ${exp_dir}/embeddings/${sre_plda_data}/xvector.scp \
--utt2spk ${data}/${sre_plda_data}/utt2spk \
--indim 256 \
--iter 10
echo "plda training finished"
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "adapt the plda model ..."
python wespeaker/bin/adapt_plda.py \
-mo ${exp_dir}/plda \
-ma ${exp_dir}/plda_adapt \
-ad ${exp_dir}/embeddings/sre16_major/xvector.scp \
-ws 0.75 \
-as 0.25
echo "plda adaptation finished"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "apply plda scoring ..."
mkdir -p ${exp_dir}/scores
trials_dir=${data}/trials
for x in $trials; do
echo "scoring on " $x
python wespeaker/bin/eval_plda.py \
--enroll_scp_path ${exp_dir}/embeddings/sre16_eval_enroll/xvector.scp \
--test_scp_path ${exp_dir}/embeddings/sre16_eval_test/xvector.scp \
--indomain_scp ${exp_dir}/embeddings/sre16_major/xvector.scp \
--utt2spk ${data}/sre16_eval_enroll/utt2spk \
--trial ${trials_dir}/${x} \
--score_path ${exp_dir}/scores/${x}.pldascore \
--model_path ${exp_dir}/plda_adapt
done
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "compute metrics (EER/minDCF) ..."
scores_dir=${exp_dir}/scores
for x in $trials; do
python wespeaker/bin/compute_metrics.py \
--p_target 0.01 \
--c_fa 1 \
--c_miss 1 \
${scores_dir}/${x}.pldascore \
2>&1 | tee -a ${scores_dir}/sre16_plda_result

echo "compute DET curve ..."
python wespeaker/bin/compute_det.py \
${scores_dir}/${x}.pldascore
done
fi
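
The last stage passes `--p_target 0.01 --c_fa 1 --c_miss 1` to `compute_metrics.py`. As a rough illustration of what these three parameters control, the sketch below evaluates the standard normalized detection cost at a single operating point; minDCF is the minimum of this quantity over all thresholds. The function name `normalized_dcf` and the example error rates are made up for illustration and are not taken from `wespeaker/utils/score_metrics.py`.

```python
def normalized_dcf(p_miss, p_fa, p_target=0.01, c_miss=1.0, c_fa=1.0):
    """Normalized detection cost at one operating point (illustrative)."""
    # Expected cost of misses and false alarms, weighted by the target prior.
    cost = c_miss * p_miss * p_target + c_fa * p_fa * (1.0 - p_target)
    # Cost of the best trivial system (always reject or always accept).
    default_cost = min(c_miss * p_target, c_fa * (1.0 - p_target))
    return cost / default_cost


# Hypothetical operating point: 20% misses, 0.5% false alarms.
dcf = normalized_dcf(p_miss=0.20, p_fa=0.005)
```

With a small `p_target`, false alarms dominate the cost, so two systems with the same EER can still differ noticeably in minDCF.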
12 changes: 11 additions & 1 deletion examples/sre/v2/run.sh
@@ -106,7 +106,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Score ..."
echo "Score using Cosine Distance..."
local/score.sh \
--stage 1 --stop-stage 2 \
--data ${data} \
@@ -115,6 +115,16 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Score with adapted PLDA ..."
local/score_plda.sh \
--stage 1 --stop-stage 4 \
--data ${data} \
--exp_dir $exp_dir \
--aug_plda_data ${aug_plda_data} \
--trials "$trials"
fi
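
Each block in these recipes is guarded by `[ ${stage} -le k ] && [ ${stop_stage} -ge k ]`, so stage `k` runs only when `stage <= k <= stop_stage`. A minimal sketch of that selection logic follows; the helper `stages_to_run` is hypothetical and exists only to make the rule explicit.

```python
def stages_to_run(stage, stop_stage, all_stages):
    """Return the stage numbers k that satisfy stage <= k <= stop_stage."""
    return [k for k in all_stages if stage <= k <= stop_stage]


# run.sh stage 6 calls local/score_plda.sh with --stage 1 --stop-stage 4.
selected = stages_to_run(1, 4, [1, 2, 3, 4])

# With the script defaults (stage=-1, stop_stage=-1) no stage is selected.
default_selected = stages_to_run(-1, -1, [1, 2, 3, 4])
```

This is why `local/score_plda.sh` does nothing unless both options are passed explicitly.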

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export the best model ..."
python wespeaker/bin/export_jit.py \
--config $exp_dir/config.yaml \
9 changes: 4 additions & 5 deletions examples/voxceleb/v2/README.md
@@ -61,9 +61,8 @@ If you are interested in the PLDA scoring (which is inferior to the simple cosin
local/score_plda.sh --stage 1 --stop-stage 3 --exp_dir exp_name
```

The results on ResNet293 (large margin, no asnorm) are:
The results on ResNet34 (large margin, no asnorm) are:

|Scoring method| vox1-O-clean | vox1-E-clean | vox1-H-clean |
| :---:|:------------:|:------------:|:------------:|
|cosine| 0.532 | 0.707 | 1.311 |
|plda | 0.744 | 0.794 | 1.374|
| Scoring method | vox1-O-clean | vox1-E-clean | vox1-H-clean |
|:--------------:|:------------:|:------------:|:------------:|
| PLDA | 1.207 | 1.350 | 2.528 |
12 changes: 6 additions & 6 deletions examples/voxceleb/v2/local/score_plda.sh
@@ -26,29 +26,29 @@ stop_stage=-1

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "train the plda model ..."
mkdir -p ${exp_dir}/scores
python wespeaker/bin/train_plda.py \
--exp_dir ${exp_dir} \
--scp_path ${exp_dir}/embeddings/vox2_dev/xvector.scp \
--utt2spk ${data}/vox2_dev/utt2spk \
--indim 256 \
--iter 5 \
--type '2cov'
--iter 5
echo "plda training finished"
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "apply plda scoring ..."
mkdir -p ${exp_dir}/scores
trials_dir=${data}/vox1/trials
for x in $trials; do
echo $x
echo "scoring on " $x
python wespeaker/bin/eval_plda.py \
--exp_dir ${exp_dir} \
--enroll_scp_path ${exp_dir}/embeddings/vox1/xvector.scp \
--test_scp_path ${exp_dir}/embeddings/vox1/xvector.scp \
--utt2spk <(cat ${data}/vox1/utt2spk | awk '{print $1, $1}') \
--trial ${trials_dir}/${x}
--trial ${trials_dir}/${x} \
--score_path ${exp_dir}/scores/${x}.pldascore \
--model_path ${exp_dir}/plda
done
fi
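
The `--utt2spk` argument above is produced by `awk '{print $1, $1}'`, which maps every utterance ID to itself. A plausible reading is that each enrollment utterance then acts as its own single-utterance "speaker", which fits trial lists made of utterance pairs. The sketch below mimics that mapping; the function name and the sample lines are hypothetical.

```python
def identity_utt2spk(lines):
    """Mimic `awk '{print $1, $1}'`: repeat the first field of each line."""
    pairs = []
    for line in lines:
        fields = line.split()
        if fields:  # blank lines are skipped here for simplicity
            pairs.append((fields[0], fields[0]))
    return pairs


# Hypothetical utt2spk lines in "<utt> <spk>" format.
vox1_utt2spk = [
    "id10001/1zcIwhmdeo4/00001.wav id10001",
    "id10001/1zcIwhmdeo4/00002.wav id10001",
]
enroll_map = identity_utt2spk(vox1_utt2spk)
```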

58 changes: 58 additions & 0 deletions wespeaker/bin/adapt_plda.py
@@ -0,0 +1,58 @@
# Copyright (c) 2023 Brno University of Technology
# Shuai Wang ([email protected])
#
# Python implementation of Kaldi unsupervised PLDA adaptation
# ( https://github.com/kaldi-asr/kaldi/blob/master/src/ivector/plda.cc#L613 )
# by Daniel Povey.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse

from wespeaker.utils.plda.two_cov_plda import TwoCovPLDA

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--adp_scp', '-ad',
                        type=str,
                        required=True,
                        help='Data for unlabeled adaptation.')
    parser.add_argument('--across_class_scale', '-as',
                        type=float,
                        help='Scaling factor for across class covariance.',
                        default=0.5)
    parser.add_argument('--within_class_scale', '-ws',
                        type=float,
                        help='Scaling factor for within class covariance.',
                        default=0.5)
    parser.add_argument('--mdl_org', '-mo',
                        type=str,
                        required=True,
                        help='Original PLDA mdl.')
    parser.add_argument('--mdl_adp', '-ma',
                        type=str,
                        required=True,
                        help='Adapted PLDA mdl.')
    parser.add_argument('--mdl_format', '-mf',
                        type=str,
                        default='wespeaker',
                        help='Format of the model wespeaker/kaldi')

    args = parser.parse_args()

    kaldi_format = args.mdl_format == 'kaldi'
    plda = TwoCovPLDA.load_model(args.mdl_org, kaldi_format)
    adapt_plda = plda.adapt(args.adp_scp, args.across_class_scale,
                            args.within_class_scale)
    adapt_plda.save_model(args.mdl_adp)
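
For intuition about `-ws` and `-as`, here is a one-dimensional sketch of how the Kaldi-style unsupervised adaptation referenced in the header distributes excess variance. It is a toy model under stated assumptions: a single direction in which the out-of-domain total variance has been normalized to 1, and in-domain variance `adapt_var` measured along that direction. The function name and numbers are illustrative; this is not the `TwoCovPLDA.adapt` code, which works on full matrices.

```python
def adapt_one_direction(within, between, adapt_var,
                        within_scale, across_scale):
    """Distribute excess in-domain variance along one direction (toy model).

    Assumes within + between == 1, i.e. the out-of-domain total variance
    has been normalized to one in this direction.
    """
    excess = adapt_var - 1.0
    if excess <= 0.0:
        # The in-domain data is not more spread out here: keep the model.
        return within, between
    # Split the excess between the two covariances according to the scales.
    return within + within_scale * excess, between + across_scale * excess


# Scales used in the SRE recipe: -ws 0.75 -as 0.25
w, b = adapt_one_direction(0.5, 0.5, adapt_var=1.8,
                           within_scale=0.75, across_scale=0.25)
```

Because the two scales sum to one in the recipe, the adapted total variance `w + b` matches the in-domain variance in that direction; with other settings it would not.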
3 changes: 2 additions & 1 deletion wespeaker/bin/average_model.py
@@ -47,7 +47,8 @@ def get_args():
def main():
args = get_args()

path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format(args.src_path))
path_list = glob.glob(
'{}/[!avg][!final][!convert]*.pt'.format(args.src_path))
path_list = sorted(
path_list,
key=lambda p: int(re.findall(r"(?<=model_)\d*(?=.pt)", p)[0]))
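
One detail worth keeping in mind when reading this pattern: in glob syntax `[!avg]` is a character class that matches one character other than `a`, `v` or `g`, not a "does not start with avg" test. The check below uses the standard `fnmatch` module on hypothetical file names to show which ones the pattern keeps, then applies the same regex key to sort them numerically.

```python
import fnmatch
import re

pattern = '[!avg][!final][!convert]*.pt'
names = ['model_10.pt', 'model_2.pt', 'avg_model.pt',
         'final_model.pt', 'convert.pt', 'vox_model_3.pt']

# Each bracket consumes exactly one character of the file name.
kept = [n for n in names if fnmatch.fnmatchcase(n, pattern)]

# Sort by the integer between "model_" and ".pt", as in average_model.py.
kept = sorted(
    kept,
    key=lambda p: int(re.findall(r"(?<=model_)\d*(?=.pt)", p)[0]))
```

The pattern does drop the `avg`, `final` and `convert` checkpoints, but it would also drop any name whose first character is `a`, `v` or `g`, such as the hypothetical `vox_model_3.pt`.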
4 changes: 3 additions & 1 deletion wespeaker/bin/compute_metrics.py
@@ -14,8 +14,10 @@
# limitations under the License.

import os
import numpy as np

import fire
import numpy as np

from wespeaker.utils.score_metrics import (compute_pmiss_pfa_rbst, compute_eer,
compute_c_norm)

29 changes: 15 additions & 14 deletions wespeaker/bin/eval_plda.py
@@ -13,26 +13,27 @@
# limitations under the License.

import argparse
import os

from wespeaker.utils.plda.two_cov_plda import TwoCovPLDA

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--type',
type=str,
default='2cov',
help='which type of plda to use')
parser.add_argument('--enroll_scp_path', type=str)
parser.add_argument('--test_scp_path', type=str)
parser.add_argument('--utt2spk', type=str)
parser.add_argument('--exp_dir', type=str)
parser.add_argument('--trial', type=str)
help='which type of plda to use, 2cov|kaldi')
parser.add_argument('--enroll_scp_path', type=str, help='enroll embeddings')
parser.add_argument('--indomain_scp_path', type=str,
help='embeddings to compute meanvec')
parser.add_argument('--test_scp_path', type=str, help='test embeddings')
parser.add_argument('--utt2spk', type=str,
help='utt2spk for the enroll speakers')
parser.add_argument('--model_path', type=str, help='pretrained plda path')
parser.add_argument('--score_path', type=str, help='score file to write to')
parser.add_argument('--trial', type=str, help='trial file to score upon')
args = parser.parse_args()

if args.type == '2cov':
model_path = os.path.join(args.exp_dir, '2cov.plda')
score_path = os.path.join(args.exp_dir, 'scores',
os.path.basename(args.trial) + '.pldascore')
plda = TwoCovPLDA.load_model(model_path)
plda.eval_sv(args.enroll_scp_path, args.utt2spk, args.test_scp_path,
args.trial, score_path)
kaldi_format = args.type == 'kaldi'
plda = TwoCovPLDA.load_model(args.model_path, kaldi_format)
plda.eval_sv(args.enroll_scp_path, args.utt2spk, args.test_scp_path,
args.trial, args.score_path, args.indomain_scp_path)
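
The SRE recipe script passes `--indomain_scp`, while the parser here declares `--indomain_scp_path`. This still works because `argparse` accepts an unambiguous prefix of a long option by default (`allow_abbrev=True`). A minimal sketch with only two of the options; the paths are placeholders:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--indomain_scp_path', type=str,
                    help='embeddings to compute meanvec')
parser.add_argument('--model_path', type=str, help='pretrained plda path')

# A unique prefix of a long option resolves to the full option name.
args = parser.parse_args(['--indomain_scp', 'emb/xvector.scp',
                          '--model_path', 'exp/plda_adapt'])
```

Spelling out the full option name in scripts is the safer choice, since a later option sharing the same prefix would make the abbreviation ambiguous.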
2 changes: 1 addition & 1 deletion wespeaker/bin/export_jit.py
@@ -20,8 +20,8 @@
import torch
import yaml

from wespeaker.utils.checkpoint import load_checkpoint
from wespeaker.models.speaker_model import get_speaker_model
from wespeaker.utils.checkpoint import load_checkpoint


def get_args():