From 59407edcad3a4a26342cee7dc7f0fac6d1ff50b4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 12:01:20 +0800 Subject: [PATCH] C++ API for speaker diarization (#1396) --- .github/scripts/test-speaker-diarization.sh | 41 ++ .../export-pyannote-segmentation-to-onnx.yaml | 2 +- .github/workflows/linux.yaml | 11 + .github/workflows/macos.yaml | 11 + .github/workflows/speaker-diarization.yaml | 2 +- .github/workflows/windows-x64.yaml | 11 + .github/workflows/windows-x86.yaml | 11 + ...s-spotter-buffered-tokens-keywords-c-api.c | 2 +- .../streaming-ctc-buffered-tokens-c-api.c | 2 +- ...reaming-paraformer-buffered-tokens-c-api.c | 2 +- ...zipformer-buffered-tokens-hotwords-c-api.c | 2 +- cmake/cmake_extension.py | 1 + scripts/pyannote/segmentation/README.md | 9 +- scripts/pyannote/segmentation/export-onnx.py | 2 +- sherpa-onnx/csrc/CMakeLists.txt | 16 + sherpa-onnx/csrc/fast-clustering-config.cc | 22 +- sherpa-onnx/csrc/macros.h | 3 + sherpa-onnx/csrc/offline-sense-voice-model.cc | 1 + .../csrc/offline-speaker-diarization-impl.cc | 26 + .../csrc/offline-speaker-diarization-impl.h | 31 + ...ffline-speaker-diarization-pyannote-impl.h | 644 ++++++++++++++++++ .../offline-speaker-diarization-result.cc | 110 +++ .../csrc/offline-speaker-diarization-result.h | 65 ++ .../csrc/offline-speaker-diarization.cc | 79 +++ .../csrc/offline-speaker-diarization.h | 73 ++ ...fline-speaker-segmentation-model-config.cc | 57 ++ ...ffline-speaker-segmentation-model-config.h | 40 ++ ...aker-segmentation-pyannote-model-config.cc | 38 ++ ...eaker-segmentation-pyannote-model-config.h | 30 + ...er-segmentation-pyannote-model-meta-data.h | 29 + ...ine-speaker-segmentation-pyannote-model.cc | 108 +++ ...line-speaker-segmentation-pyannote-model.h | 40 ++ sherpa-onnx/csrc/provider-config.cc | 6 +- sherpa-onnx/csrc/session.cc | 42 +- sherpa-onnx/csrc/session.h | 45 +- ...sherpa-onnx-offline-speaker-diarization.cc | 133 ++++ sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc | 7 +- .../csrc/sherpa-onnx-online-punctuation.cc | 2 +- .../csrc/speaker-embedding-extractor.cc | 4 +- 39 files changed, 1652 insertions(+), 108 deletions(-) create mode 100755 .github/scripts/test-speaker-diarization.sh create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-impl.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-impl.h create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-result.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-result.h create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h create mode 100644 sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc diff --git a/.github/scripts/test-speaker-diarization.sh b/.github/scripts/test-speaker-diarization.sh new file mode 100755 index 000000000..6d7b2effd --- /dev/null +++ 
b/.github/scripts/test-speaker-diarization.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +log "specify number of clusters" +$EXE \ + --clustering.num-clusters=4 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +log "specify threshold for clustering" + +$EXE \ + --clustering.cluster-threshold=0.90 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +rm -rf sherpa-onnx-pyannote-* +rm -fv *.onnx +rm -fv *.wav diff --git a/.github/workflows/export-pyannote-segmentation-to-onnx.yaml b/.github/workflows/export-pyannote-segmentation-to-onnx.yaml index 300aca500..ece0ffa28 100644 --- a/.github/workflows/export-pyannote-segmentation-to-onnx.yaml +++ b/.github/workflows/export-pyannote-segmentation-to-onnx.yaml @@ -29,7 +29,7 @@ jobs: - name: Install pyannote shell: bash run: | - pip install pyannote.audio onnx onnxruntime + pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3 - name: Run shell: bash diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 0e1eca099..1d3e8dc7b 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -18,6 +18,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -38,6 +39,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -143,6 +145,15 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . 
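+        # .github/scripts/test-speaker-diarization.sh reads the EXE environment
+        # variable to locate the binary under test (see `which $EXE` in the script)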
+ export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization + + .github/scripts/test-speaker-diarization.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 084531e4a..f3d70f583 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -18,6 +18,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -37,6 +38,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -115,6 +117,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization + + .github/scripts/test-speaker-diarization.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/speaker-diarization.yaml b/.github/workflows/speaker-diarization.yaml index 0bd6a575c..ab2a4f090 100644 --- a/.github/workflows/speaker-diarization.yaml +++ b/.github/workflows/speaker-diarization.yaml @@ -67,7 +67,7 @@ jobs: curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin test_wavs=( - 0-two-speakers-zh.wav + 0-four-speakers-zh.wav 1-two-speakers-en.wav 2-two-speakers-en.wav 3-two-speakers-en.wav diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 2d2811c31..c67f3e0b5 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -17,6 +17,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -34,6 +35,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -87,6 +89,15 @@ jobs: name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . 
+ export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization.exe + + .github/scripts/test-speaker-diarization.sh + - name: Test online punctuation shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 316cef626..30394e90e 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -17,6 +17,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -34,6 +35,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -87,6 +89,15 @@ jobs: name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization.exe + + .github/scripts/test-speaker-diarization.sh + - name: Test online punctuation shell: bash run: | diff --git a/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c index ec8be3b07..45a0bb87a 100644 --- a/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c +++ b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c index 3223772a8..33690e008 100644 --- a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c index cd87177b5..a597374df 100644 --- a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c index d5092c5cc..c991d4999 100644 --- 
a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +++ b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 672e3d17a..c49c32555 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -55,6 +55,7 @@ def get_binaries(): "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-offline-language-identification", "sherpa-onnx-offline-punctuation", + "sherpa-onnx-offline-speaker-diarization", "sherpa-onnx-offline-tts", "sherpa-onnx-offline-tts-play", "sherpa-onnx-offline-websocket-server", diff --git a/scripts/pyannote/segmentation/README.md b/scripts/pyannote/segmentation/README.md index a2e35b2de..a9c5230d1 100644 --- a/scripts/pyannote/segmentation/README.md +++ b/scripts/pyannote/segmentation/README.md @@ -3,12 +3,9 @@ Please download test wave files from https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models -## 0-two-speakers-zh.wav +## 0-four-speakers-zh.wav -This file is from -https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0 - -Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`. +It is recorded by @csukuangfj ## 1-two-speakers-en.wav @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav` ```bash -sox ML16091-Audio.mp3 3-two-speakers-en.wav +sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav ``` diff --git a/scripts/pyannote/segmentation/export-onnx.py b/scripts/pyannote/segmentation/export-onnx.py index 5f6e79c7e..feb241a26 100755 --- a/scripts/pyannote/segmentation/export-onnx.py +++ b/scripts/pyannote/segmentation/export-onnx.py @@ -72,7 +72,7 @@ def main(): model.receptive_field.duration * 16000 ) - opset_version = 18 + opset_version = 13 filename = "model.onnx" torch.onnx.export( diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e49fdeed4..3e6526563 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) list(APPEND sources fast-clustering-config.cc fast-clustering.cc + offline-speaker-diarization-impl.cc + offline-speaker-diarization-result.cc + offline-speaker-diarization.cc + offline-speaker-segmentation-model-config.cc + offline-speaker-segmentation-pyannote-model-config.cc + offline-speaker-segmentation-pyannote-model.cc ) endif() @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) endif() + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc) + endif() + set(main_exes sherpa-onnx sherpa-onnx-keyword-spotter @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY) ) endif() + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + list(APPEND main_exes + sherpa-onnx-offline-speaker-diarization + ) + endif() + foreach(exe IN LISTS main_exes) target_link_libraries(${exe} sherpa-onnx-core) endforeach() diff --git a/sherpa-onnx/csrc/fast-clustering-config.cc 
b/sherpa-onnx/csrc/fast-clustering-config.cc
index e8382e598..e4f64fbbb 100644
--- a/sherpa-onnx/csrc/fast-clustering-config.cc
+++ b/sherpa-onnx/csrc/fast-clustering-config.cc
@@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
 }
 
 void FastClusteringConfig::Register(ParseOptions *po) {
-  std::string prefix = "ctc";
-  ParseOptions p(prefix, po);
-
-  p.Register("num-clusters", &num_clusters,
-             "Number of cluster. If greater than 0, then --cluster-thresold is "
-             "ignored. Please provide it if you know the actual number of "
-             "clusters in advance.");
-
-  p.Register("cluster-threshold", &threshold,
-             "If --num-clusters is not specified, then it specifies the "
-             "distance threshold for clustering. smaller value -> more "
-             "clusters. larger value -> fewer clusters");
+  po->Register(
+      "num-clusters", &num_clusters,
+      "Number of clusters. If greater than 0, then cluster threshold is "
+      "ignored. Please provide it if you know the actual number of "
+      "clusters in advance.");
+
+  po->Register("cluster-threshold", &threshold,
+               "If num_clusters is not specified, then it specifies the "
+               "distance threshold for clustering. Smaller value -> more "
+               "clusters. Larger value -> fewer clusters");
 }
 
 bool FastClusteringConfig::Validate() const {
diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h
index b5dfb99e3..6bd6f62a6 100644
--- a/sherpa-onnx/csrc/macros.h
+++ b/sherpa-onnx/csrc/macros.h
@@ -5,6 +5,7 @@
 #ifndef SHERPA_ONNX_CSRC_MACROS_H_
 #define SHERPA_ONNX_CSRC_MACROS_H_
 #include <stdio.h>
+#include <stdlib.h>
 
 #if __ANDROID_API__ >= 8
 #include "android/log.h"
@@ -169,4 +170,6 @@
   }            \
   } while (0)
 
+#define SHERPA_ONNX_EXIT(code) exit(code)
+
 #endif  // SHERPA_ONNX_CSRC_MACROS_H_
diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc
index 1d2a14ef5..24903a41a 100644
--- a/sherpa-onnx/csrc/offline-sense-voice-model.cc
+++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc
@@ -9,6 +9,7 @@
 #include
 
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
new file mode 100644
index 000000000..e41a7767a
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
@@ -0,0 +1,26 @@
+// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
+
+#include <memory>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<OfflineSpeakerDiarizationImpl>
+OfflineSpeakerDiarizationImpl::Create(
+    const OfflineSpeakerDiarizationConfig &config) {
+  if (!config.segmentation.pyannote.model.empty()) {
+    return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
+  }
+
+  SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
+
+  return nullptr;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h
new file mode 100644
index 000000000..f7fe39499
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h
@@ -0,0 +1,31 @@
+// sherpa-onnx/csrc/offline-speaker-diarization-impl.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
+
+#include <memory>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-speaker-diarization.h" +namespace sherpa_onnx { + +class OfflineSpeakerDiarizationImpl { + public: + static std::unique_ptr Create( + const OfflineSpeakerDiarizationConfig &config); + + virtual ~OfflineSpeakerDiarizationImpl() = default; + + virtual int32_t SampleRate() const = 0; + + virtual OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h new file mode 100644 index 000000000..bcd0c93a4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -0,0 +1,644 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ + +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-onnx/csrc/fast-clustering.h" +#include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +namespace { // NOLINT + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41 +template +inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT +} + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47 +struct PairHash { + template + std::size_t operator()(const std::pair &pair) const { + std::size_t result = 0; + hash_combine(&result, pair.first); + hash_combine(&result, pair.second); + return result; + } +}; +} // namespace + +using Matrix2D = + Eigen::Matrix; + +using Matrix2DInt32 = + Eigen::Matrix; + +using FloatRowVector = Eigen::Matrix; +using Int32RowVector = Eigen::Matrix; + +using Int32Pair = std::pair; + +class OfflineSpeakerDiarizationPyannoteImpl + : public OfflineSpeakerDiarizationImpl { + public: + ~OfflineSpeakerDiarizationPyannoteImpl() override = default; + + explicit OfflineSpeakerDiarizationPyannoteImpl( + const OfflineSpeakerDiarizationConfig &config) + : config_(config), + segmentation_model_(config_.segmentation), + embedding_extractor_(config_.embedding), + clustering_(config_.clustering) { + Init(); + } + + int32_t SampleRate() const override { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + + return meta_data.sample_rate; + } + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const override { + std::vector segmentations = RunSpeakerSegmentationModel(audio, n); + // segmentations[i] is for chunk_i + // Each matrix is of shape (num_frames, num_powerset_classes) + if (segmentations.empty()) { + return {}; + } + + std::vector labels; + labels.reserve(segmentations.size()); + + for (const auto &m : segmentations) { + labels.push_back(ToMultiLabel(m)); + } + + segmentations.clear(); + + // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) + + // speaker 
count per frame + Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels); + + if (speakers_per_frame.maxCoeff() == 0) { + SHERPA_ONNX_LOGE("No speakers found in the audio samples"); + return {}; + } + + auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + Matrix2D embeddings = + ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, + callback, callback_arg); + + std::vector cluster_labels = clustering_.Cluster( + &embeddings(0, 0), embeddings.rows(), embeddings.cols()); + + int32_t max_cluster_index = + *std::max_element(cluster_labels.begin(), cluster_labels.end()); + + auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster( + chunk_speaker_samples_list_pair.first, cluster_labels); + + auto new_labels = + ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster); + + Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n); + + Matrix2DInt32 final_labels = + FinalizeLabels(speaker_count, speakers_per_frame); + + auto result = ComputeResult(final_labels); + + return result; + } + + private: + void Init() { InitPowersetMapping(); } + + // see also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68 + void InitPowersetMapping() { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t num_classes = meta_data.num_classes; + int32_t powerset_max_classes = meta_data.powerset_max_classes; + int32_t num_speakers = meta_data.num_speakers; + + powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers); + powerset_mapping_.setZero(); + + int32_t k = 1; + for (int32_t i = 1; i <= powerset_max_classes; ++i) { + if (i == 1) { + for (int32_t j = 0; j != num_speakers; ++j, ++k) { + powerset_mapping_(k, j) = 1; + } + } else if (i == 2) { + for (int32_t j = 0; j != num_speakers; ++j) { + for (int32_t m = j + 1; m < num_speakers; ++m, ++k) { + powerset_mapping_(k, j) = 1; + powerset_mapping_(k, m) = 1; + } + } + } else { + SHERPA_ONNX_LOGE( + "powerset_max_classes = %d is currently not supported!", i); + SHERPA_ONNX_EXIT(-1); + } + } + } + + std::vector RunSpeakerSegmentationModel(const float *audio, + int32_t n) const { + std::vector ans; + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + + if (n <= 0) { + SHERPA_ONNX_LOGE( + "number of audio samples is %d (<= 0). 
Please provide a positive " + "number", + n); + return {}; + } + + if (n <= window_size) { + std::vector buf(window_size); + // NOTE: buf is zero initialized by default + + std::copy(audio, audio + n, buf.data()); + + Matrix2D m = ProcessChunk(buf.data()); + + ans.push_back(std::move(m)); + + return ans; + } + + int32_t num_chunks = (n - window_size) / window_shift + 1; + bool has_last_chunk = (n - window_size) % window_shift > 0; + + ans.reserve(num_chunks + has_last_chunk); + + const float *p = audio; + + for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) { + Matrix2D m = ProcessChunk(p); + + ans.push_back(std::move(m)); + } + + if (has_last_chunk) { + std::vector buf(window_size); + std::copy(p, audio + n, buf.data()); + + Matrix2D m = ProcessChunk(buf.data()); + + ans.push_back(std::move(m)); + } + + return ans; + } + + Matrix2D ProcessChunk(const float *p) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array shape = {1, 1, window_size}; + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(p), + window_size, shape.data(), shape.size()); + + Ort::Value out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = out.GetTensorTypeAndShapeInfo().GetShape(); + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out.GetTensorData(), out.GetTensorData() + m.size(), + &m(0, 0)); + return m; + } + + Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const { + int32_t num_rows = m.rows(); + Matrix2DInt32 ans(num_rows, powerset_mapping_.cols()); + + std::ptrdiff_t col_id; + + for (int32_t i = 0; i != num_rows; ++i) { + m.row(i).maxCoeff(&col_id); + ans.row(i) = powerset_mapping_.row(col_id); + } + + return ans; + } + + // See also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122 + Int32RowVector ComputeSpeakersPerFrame( + const std::vector &labels) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + FloatRowVector count(num_frames); + FloatRowVector weight(num_frames); + count.setZero(); + weight.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = + static_cast(i) * window_shift / receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq).array() += labels[i].rowwise().sum().array().cast(); + + weight(seq).array() += 1; + } + + return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast(); + } + + // ans.first: a list of (chunk_id, speaker_id) + // ans.second: a list of list of (start_sample_index, end_sample_index) + // + // ans.first[i] corresponds to ans.second[i] + std::pair, std::vector>> + GetChunkSpeakerSampleIndexes(const std::vector &labels) const { + auto new_labels = ExcludeOverlap(labels); + + std::vector chunk_speaker_list; + std::vector> samples_index_list; + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t num_speakers = 
meta_data.num_speakers; + + int32_t chunk_index = 0; + for (const auto &label : new_labels) { + Matrix2DInt32 tmp = label.transpose(); + // tmp: (num_speakers, num_frames) + int32_t num_frames = tmp.cols(); + + int32_t sample_offset = chunk_index * window_shift; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + auto d = tmp.row(speaker_index); + if (d.sum() < 10) { + // skip segments less than 10 frames + continue; + } + + Int32Pair this_chunk_speaker = {chunk_index, speaker_index}; + std::vector this_speaker_samples; + + bool is_active = false; + int32_t start_index; + + for (int32_t k = 0; k != num_frames; ++k) { + if (d[k] != 0) { + if (!is_active) { + is_active = true; + start_index = k; + } + } else if (is_active) { + is_active = false; + + int32_t start_samples = + static_cast(start_index) / num_frames * window_size + + sample_offset; + int32_t end_samples = + static_cast(k) / num_frames * window_size + + sample_offset; + + this_speaker_samples.emplace_back(start_samples, end_samples); + } + } + + if (is_active) { + int32_t start_samples = + static_cast(start_index) / num_frames * window_size + + sample_offset; + int32_t end_samples = + static_cast(num_frames - 1) / num_frames * window_size + + sample_offset; + this_speaker_samples.emplace_back(start_samples, end_samples); + } + + chunk_speaker_list.push_back(std::move(this_chunk_speaker)); + samples_index_list.push_back(std::move(this_speaker_samples)); + } // for (int32_t speaker_index = 0; + chunk_index += 1; + } // for (const auto &label : new_labels) + + return {chunk_speaker_list, samples_index_list}; + } + + // If there are multiple speakers at a frame, then this frame is excluded. + std::vector ExcludeOverlap( + const std::vector &labels) const { + int32_t num_chunks = labels.size(); + std::vector ans; + ans.reserve(num_chunks); + + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), label.cols()); + new_label.setZero(); + Int32RowVector v = label.rowwise().sum(); + + for (int32_t i = 0; i != v.cols(); ++i) { + if (v[i] < 2) { + new_label.row(i) = label.row(i); + } + } + + ans.push_back(std::move(new_label)); + } + + return ans; + } + + /** + * @param sample_indexes[i] contains the sample segment start and end indexes + * for the i-th (chunk, speaker) pair + * @return Return a matrix of shape (sample_indexes.size(), embedding_dim) + * where ans.row[i] contains the embedding for the + * i-th (chunk, speaker) pair + */ + Matrix2D ComputeEmbeddings( + const float *audio, int32_t n, + const std::vector> &sample_indexes, + OfflineSpeakerDiarizationProgressCallback callback, + void *callback_arg) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t sample_rate = meta_data.sample_rate; + Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); + + int32_t k = 0; + for (const auto &v : sample_indexes) { + auto stream = embedding_extractor_.CreateStream(); + for (const auto &p : v) { + int32_t end = (p.second <= n) ? 
p.second : n; + int32_t num_samples = end - p.first; + + if (num_samples > 0) { + stream->AcceptWaveform(sample_rate, audio + p.first, num_samples); + } + } + + stream->InputFinished(); + if (!embedding_extractor_.IsReady(stream.get())) { + SHERPA_ONNX_LOGE( + "This segment is too short, which should not happen since we have " + "already filtered short segments"); + SHERPA_ONNX_EXIT(-1); + } + + std::vector embedding = embedding_extractor_.Compute(stream.get()); + + std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); + + k += 1; + + if (callback) { + callback(k, ans.rows(), callback_arg); + } + } + + return ans; + } + + std::unordered_map ConvertChunkSpeakerToCluster( + const std::vector &chunk_speaker_pair, + const std::vector &cluster_labels) const { + std::unordered_map ans; + + int32_t k = 0; + for (const auto &p : chunk_speaker_pair) { + ans[p] = cluster_labels[k]; + k += 1; + } + + return ans; + } + + std::vector ReLabel( + const std::vector &labels, int32_t max_cluster_index, + std::unordered_map chunk_speaker_to_cluster) + const { + std::vector new_labels; + new_labels.reserve(labels.size()); + + int32_t chunk_index = 0; + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1); + new_label.setZero(); + + Matrix2DInt32 t = label.transpose(); + // t: (num_speakers, num_frames) + + for (int32_t speaker_index = 0; speaker_index != t.rows(); + ++speaker_index) { + if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) { + continue; + } + + int32_t new_speaker_index = + chunk_speaker_to_cluster.at({chunk_index, speaker_index}); + + for (int32_t k = 0; k != t.cols(); ++k) { + if (t(speaker_index, k) == 1) { + new_label(k, new_speaker_index) = 1; + } + } + } + + new_labels.push_back(std::move(new_label)); + + chunk_index += 1; + } + + return new_labels; + } + + Matrix2DInt32 ComputeSpeakerCount(const std::vector &labels, + int32_t num_samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + Matrix2DInt32 count(num_frames, labels[0].cols()); + count.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = + static_cast(i) * window_shift / receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq, Eigen::all).array() += labels[i].array(); + } + + bool has_last_chunk = (num_samples - window_size) % window_shift > 0; + + if (has_last_chunk) { + return count; + } + + int32_t last_frame = num_samples / receptive_field_shift; + return count(Eigen::seq(0, last_frame), Eigen::all); + } + + Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count, + const Int32RowVector &speakers_per_frame) const { + int32_t num_rows = count.rows(); + int32_t num_cols = count.cols(); + + Matrix2DInt32 ans(num_rows, num_cols); + ans.setZero(); + + for (int32_t i = 0; i != num_rows; ++i) { + int32_t k = speakers_per_frame[i]; + if (k == 0) { + continue; + } + auto top_k = TopkIndex(&count(i, 0), num_cols, k); + + for (int32_t m : top_k) { + ans(i, m) = 1; + } + } + + return ans; + } + + OfflineSpeakerDiarizationResult ComputeResult( + const Matrix2DInt32 &final_labels) const { + Matrix2DInt32 final_labels_t = final_labels.transpose(); + int32_t num_speakers = 
final_labels_t.rows(); + int32_t num_frames = final_labels_t.cols(); + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t receptive_field_size = meta_data.receptive_field_size; + int32_t sample_rate = meta_data.sample_rate; + + float scale = static_cast(receptive_field_shift) / sample_rate; + float scale_offset = 0.5 * receptive_field_size / sample_rate; + + OfflineSpeakerDiarizationResult ans; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + std::vector this_speaker; + + bool is_active = final_labels_t(speaker_index, 0) > 0; + int32_t start_index = is_active ? 0 : -1; + + for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) { + if (is_active) { + if (final_labels_t(speaker_index, frame_index) == 0) { + float start_time = start_index * scale + scale_offset; + float end_time = frame_index * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + + is_active = false; + } + } else if (final_labels_t(speaker_index, frame_index) == 1) { + is_active = true; + start_index = frame_index; + } + } + + if (is_active) { + float start_time = start_index * scale + scale_offset; + float end_time = (num_frames - 1) * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + } + + // merge segments if the gap between them is less than min_duration_off + MergeSegments(&this_speaker); + + for (const auto &seg : this_speaker) { + if (seg.Duration() > config_.min_duration_on) { + ans.Add(seg); + } + } + } // for (int32_t speaker_index = 0; speaker_index != num_speakers; + + return ans; + } + + void MergeSegments( + std::vector *segments) const { + float min_duration_off = config_.min_duration_off; + bool changed = true; + while (changed) { + changed = false; + for (int32_t i = 0; i < static_cast(segments->size()) - 1; ++i) { + auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off); + if (s) { + (*segments)[i] = s.value(); + segments->erase(segments->begin() + i + 1); + + changed = true; + break; + } + } + } + } + + private: + OfflineSpeakerDiarizationConfig config_; + OfflineSpeakerSegmentationPyannoteModel segmentation_model_; + SpeakerEmbeddingExtractor embedding_extractor_; + FastClustering clustering_; + Matrix2DInt32 powerset_mapping_; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc new file mode 100644 index 000000000..8bf83f5d9 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc @@ -0,0 +1,110 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-result.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( + float start, float end, int32_t speaker, const std::string &text /*= {}*/) { + if (start > end) { + SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end); + SHERPA_ONNX_EXIT(-1); + } + + start_ = start; + 
end_ = end; + speaker_ = speaker; + text_ = text; +} + +std::optional +OfflineSpeakerDiarizationSegment::Merge( + const OfflineSpeakerDiarizationSegment &other, float gap) const { + if (other.speaker_ != speaker_) { + SHERPA_ONNX_LOGE( + "The two segments should have the same speaker. this->speaker: %d, " + "other.speaker: %d", + speaker_, other.speaker_); + return std::nullopt; + } + + if (end_ < other.start_ && end_ + gap >= other.start_) { + return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_); + } else if (other.end_ < start_ && other.end_ + gap >= start_) { + return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_); + } else { + return std::nullopt; + } +} + +std::string OfflineSpeakerDiarizationSegment::ToString() const { + char s[128]; + snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_); + + std::ostringstream os; + os << s; + + if (!text_.empty()) { + os << " " << text_; + } + + return os.str(); +} + +void OfflineSpeakerDiarizationResult::Add( + const OfflineSpeakerDiarizationSegment &segment) { + segments_.push_back(segment); +} + +int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const { + std::unordered_set count; + for (const auto &s : segments_) { + count.insert(s.Speaker()); + } + + return count.size(); +} + +int32_t OfflineSpeakerDiarizationResult::NumSegments() const { + return segments_.size(); +} + +// Return a list of segments sorted by segment.start time +std::vector +OfflineSpeakerDiarizationResult::SortByStartTime() const { + auto ans = segments_; + std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) { + return (a.Start() < b.Start()) || + ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker())); + }); + + return ans; +} + +std::vector> +OfflineSpeakerDiarizationResult::SortBySpeaker() const { + auto tmp = segments_; + std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) { + return (a.Speaker() < b.Speaker()) || + ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start())); + }); + + std::vector> ans(NumSpeakers()); + for (auto &s : tmp) { + ans[s.Speaker()].push_back(std::move(s)); + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h new file mode 100644 index 000000000..e71d054e5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -0,0 +1,65 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-result.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ + +#include +#include +#include +#include + +namespace sherpa_onnx { + +class OfflineSpeakerDiarizationSegment { + public: + OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker, + const std::string &text = {}); + + // If the gap between the two segments is less than the given gap, then we + // merge them and return a new segment. Otherwise, it returns null. 
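+  // E.g., with gap = 0.5, merging the segments [0.0, 1.0] and [1.2, 1.8]
+  // of the same speaker yields [0.0, 1.8]; with gap = 0.1 the two segments
+  // stay apart and std::nullopt is returned.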
+  std::optional<OfflineSpeakerDiarizationSegment> Merge(
+      const OfflineSpeakerDiarizationSegment &other, float gap) const;
+
+  float Start() const { return start_; }
+  float End() const { return end_; }
+  int32_t Speaker() const { return speaker_; }
+  const std::string &Text() const { return text_; }
+  float Duration() const { return end_ - start_; }
+
+  std::string ToString() const;
+
+ private:
+  float start_;       // in seconds
+  float end_;         // in seconds
+  int32_t speaker_;   // ID of the speaker, starting from 0
+  std::string text_;  // If not empty, it contains the speech recognition result
+                      // of this segment
+};
+
+class OfflineSpeakerDiarizationResult {
+ public:
+  // Add a new segment
+  void Add(const OfflineSpeakerDiarizationSegment &segment);
+
+  // Number of distinct speakers contained in this object at this point
+  int32_t NumSpeakers() const;
+
+  int32_t NumSegments() const;
+
+  // Return a list of segments sorted by segment.start time
+  std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;
+
+  // ans.size() == NumSpeakers().
+  // ans[i] is for speaker_i and is sorted by start time
+  std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
+      const;
+
+ public:
+  std::vector<OfflineSpeakerDiarizationSegment> segments_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc
new file mode 100644
index 000000000..aeff9b42d
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc
@@ -0,0 +1,79 @@
+// sherpa-onnx/csrc/offline-speaker-diarization.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
+
+namespace sherpa_onnx {
+
+void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) {
+  ParseOptions po_segmentation("segmentation", po);
+  segmentation.Register(&po_segmentation);
+
+  ParseOptions po_embedding("embedding", po);
+  embedding.Register(&po_embedding);
+
+  ParseOptions po_clustering("clustering", po);
+  clustering.Register(&po_clustering);
+
+  po->Register("min-duration-on", &min_duration_on,
+               "if a segment is less than this value, then it is discarded. "
+               "Set it to 0 so that no segment is discarded");
+
+  po->Register("min-duration-off", &min_duration_off,
+               "if the gap between two segments of the same speaker is less "
+               "than this value, then these two segments are merged into a "
+               "single segment. We do it recursively.");
+}
+
+bool OfflineSpeakerDiarizationConfig::Validate() const {
+  if (!segmentation.Validate()) {
+    return false;
+  }
+
+  if (!embedding.Validate()) {
+    return false;
+  }
+
+  if (!clustering.Validate()) {
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineSpeakerDiarizationConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSpeakerDiarizationConfig(";
+  os << "segmentation=" << segmentation.ToString() << ", ";
+  os << "embedding=" << embedding.ToString() << ", ";
+  os << "clustering=" << clustering.ToString() << ", ";
+  os << "min_duration_on=" << min_duration_on << ", ";
+  os << "min_duration_off=" << min_duration_off << ")";
+
+  return os.str();
+}
+
+OfflineSpeakerDiarization::OfflineSpeakerDiarization(
+    const OfflineSpeakerDiarizationConfig &config)
+    : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
+
+OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
+
+int32_t OfflineSpeakerDiarization::SampleRate() const {
+  return impl_->SampleRate();
+}
+
+OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
+    const float *audio, int32_t n,
+    OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
+    void *callback_arg /*= nullptr*/) const {
+  return impl_->Process(audio, n, callback, callback_arg);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h
new file mode 100644
index 000000000..ab9a440aa
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-diarization.h
@@ -0,0 +1,73 @@
+// sherpa-onnx/csrc/offline-speaker-diarization.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/fast-clustering-config.h"
+#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
+#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
+
+namespace sherpa_onnx {
+
+struct OfflineSpeakerDiarizationConfig {
+  OfflineSpeakerSegmentationModelConfig segmentation;
+  SpeakerEmbeddingExtractorConfig embedding;
+  FastClusteringConfig clustering;
+
+  // if a segment is less than this value, then it is discarded
+  float min_duration_on = 0.3;  // in seconds
+
+  // if the gap between two segments of the same speaker is less than this
+  // value, then these two segments are merged into a single segment.
+  // We do this recursively.
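+  // E.g., with min_duration_off = 0.5, the segments [0.0, 1.0] and
+  // [1.3, 2.0] of the same speaker are merged into [0.0, 2.0].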
+ float min_duration_off = 0.5; // in seconds + + OfflineSpeakerDiarizationConfig() = default; + + OfflineSpeakerDiarizationConfig( + const OfflineSpeakerSegmentationModelConfig &segmentation, + const SpeakerEmbeddingExtractorConfig &embedding, + const FastClusteringConfig &clustering) + : segmentation(segmentation), + embedding(embedding), + clustering(clustering) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class OfflineSpeakerDiarizationImpl; + +using OfflineSpeakerDiarizationProgressCallback = std::function; + +class OfflineSpeakerDiarization { + public: + explicit OfflineSpeakerDiarization( + const OfflineSpeakerDiarizationConfig &config); + + ~OfflineSpeakerDiarization(); + + // Expected sample rate of the input audio samples + int32_t SampleRate() const; + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc new file mode 100644 index 000000000..f1c9f7d4a --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc @@ -0,0 +1,57 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) { + pyannote.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineSpeakerSegmentationModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!pyannote.model.empty()) { + return pyannote.Validate(); + } + + if (pyannote.model.empty()) { + SHERPA_ONNX_LOGE( + "You have to provide at least one speaker segmentation model"); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationModelConfig("; + os << "pyannote=" << pyannote.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h new file mode 100644 index 000000000..8e9e4a96e --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSpeakerSegmentationModelConfig { + OfflineSpeakerSegmentationPyannoteModelConfig pyannote; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineSpeakerSegmentationModelConfig() = default; + + explicit OfflineSpeakerSegmentationModelConfig( + const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote, + int32_t num_threads, bool debug, const std::string &provider) + : pyannote(pyannote), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc new file mode 100644 index 000000000..f7417ea83 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) { + po->Register("pyannote-model", &model, + "Path to model.onnx of the Pyannote segmentation model."); +} + +bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationPyannoteModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h new file mode 100644 index 000000000..fb5ca4a48 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h @@ -0,0 +1,30 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx 
{ + +struct OfflineSpeakerSegmentationPyannoteModelConfig { + std::string model; + + OfflineSpeakerSegmentationPyannoteModelConfig() = default; + + explicit OfflineSpeakerSegmentationPyannoteModelConfig( + const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h new file mode 100644 index 000000000..728ed7ff4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_onnx { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineSpeakerSegmentationPyannoteModelMetaData { + int32_t sample_rate = 0; + int32_t window_size = 0; // in samples + int32_t window_shift = 0; // in samples + int32_t receptive_field_size = 0; // in samples + int32_t receptive_field_shift = 0; // in samples + int32_t num_speakers = 0; + int32_t powerset_max_classes = 0; + int32_t num_classes = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc new file mode 100644 index 000000000..3f3323698 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -0,0 +1,108 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class OfflineSpeakerSegmentationPyannoteModel::Impl { + public: + explicit Impl(const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.pyannote.model); + Init(buf.data(), buf.size()); + } + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const { + return meta_data_; + } + + Ort::Value Forward(Ort::Value x) { + auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1, + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions 
diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
new file mode 100644
index 000000000..3f3323698
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
@@ -0,0 +1,108 @@
+// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
+
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+
+namespace sherpa_onnx {
+
+class OfflineSpeakerSegmentationPyannoteModel::Impl {
+ public:
+  explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(config_.pyannote.model);
+    Init(buf.data(), buf.size());
+  }
+
+  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
+      const {
+    return meta_data_;
+  }
+
+  Ort::Value Forward(Ort::Value x) {
+    auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1,
+                          output_names_ptr_.data(), output_names_ptr_.size());
+
+    return std::move(out[0]);
+  }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data,
+                                           model_data_length, sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");
+
+    meta_data_.window_shift =
+        static_cast<int32_t>(0.1 * meta_data_.window_size);
+
+    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
+                               "receptive_field_size");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
+                               "receptive_field_shift");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
+                               "powerset_max_classes");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
+  }
+
+ private:
+  OfflineSpeakerSegmentationModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
+};
+
+OfflineSpeakerSegmentationPyannoteModel::
+    OfflineSpeakerSegmentationPyannoteModel(
+        const OfflineSpeakerSegmentationModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineSpeakerSegmentationPyannoteModel::
+    ~OfflineSpeakerSegmentationPyannoteModel() = default;
+
+const OfflineSpeakerSegmentationPyannoteModelMetaData &
+OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const {
+  return impl_->GetModelMetaData();
+}
+
+Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
+    Ort::Value x) const {
+  return impl_->Forward(std::move(x));
+}
+
+}  // namespace sherpa_onnx
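For concreteness: assuming the published pyannote segmentation-3.0 configuration (16 kHz audio and 10-second windows, i.e., sample_rate = 16000 and window_size = 160000 samples), the hard-coded 10% shift above gives window_shift = 16000 samples, so the model is evaluated on a 10-second window once per second and consecutive windows overlap by 9 seconds. These concrete numbers are taken from the upstream pyannote model card, not from this patch.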
diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
new file mode 100644
index 000000000..b504c373f
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
@@ -0,0 +1,40 @@
+// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
+
+#include <memory>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
+#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
+
+namespace sherpa_onnx {
+
+class OfflineSpeakerSegmentationPyannoteModel {
+ public:
+  explicit OfflineSpeakerSegmentationPyannoteModel(
+      const OfflineSpeakerSegmentationModelConfig &config);
+
+  ~OfflineSpeakerSegmentationPyannoteModel();
+
+  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
+      const;
+
+  /**
+   * @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
+   * @return Return a float tensor of
+   *         shape (batch_size, num_frames, num_classes). Note that
+   *         the last dimension uses powerset encoding of speaker subsets.
+   */
+  Ort::Value Forward(Ort::Value x) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
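Because Forward() returns powerset-encoded classes, a caller has to map each frame's argmax class back to a set of active speakers; in this patch that conversion presumably lives in offline-speaker-diarization-pyannote-impl.h, which is not shown in this excerpt. The following is a simplified sketch, assuming num_speakers = 3, powerset_max_classes = 2, and the subset ordering in kClassToSpeakers (the model's actual class ordering may differ).

    #include <cstdint>
    #include <vector>

    // Map a (num_frames, num_classes) matrix of powerset scores to
    // per-frame sets of active speaker indices.
    std::vector<std::vector<int32_t>> DecodePowerset(
        const float *p, int32_t num_frames, int32_t num_classes) {
      // class index -> active speakers; assumed ordering, illustrative only
      static const std::vector<std::vector<int32_t>> kClassToSpeakers = {
          {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}};

      std::vector<std::vector<int32_t>> ans(num_frames);
      for (int32_t t = 0; t != num_frames; ++t) {
        const float *row = p + t * num_classes;
        int32_t best = 0;
        for (int32_t c = 1; c != num_classes; ++c) {
          if (row[c] > row[best]) best = c;  // argmax over powerset classes
        }
        ans[t] = kClassToSpeakers[best];
      }
      return ans;
    }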
"sherpa-onnx/csrc/offline-lm-config.h" -#include "sherpa-onnx/csrc/offline-model-config.h" -#include "sherpa-onnx/csrc/offline-punctuation-model-config.h" -#include "sherpa-onnx/csrc/online-punctuation-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" -#include "sherpa-onnx/csrc/spoken-language-identification.h" -#include "sherpa-onnx/csrc/vad-model-config.h" - -#if SHERPA_ONNX_ENABLE_TTS -#include "sherpa-onnx/csrc/offline-tts-model-config.h" -#endif namespace sherpa_onnx { -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); - -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, - const std::string &model_type); - -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); +Ort::SessionOptions GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, + const ProviderConfig *provider_config = nullptr); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); - Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); - -#if SHERPA_ONNX_ENABLE_TTS -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); -#endif - -Ort::SessionOptions GetSessionOptions( - const SpeakerEmbeddingExtractorConfig &config); - -Ort::SessionOptions GetSessionOptions( - const SpokenLanguageIdentificationConfig &config); - -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); -Ort::SessionOptions GetSessionOptions( - const OfflinePunctuationModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); -Ort::SessionOptions GetSessionOptions( - const OnlinePunctuationModelConfig &config); +template +Ort::SessionOptions GetSessionOptions(const T &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc new file mode 100644 index 000000000..170973114 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -0,0 +1,133 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks, + void *arg) { + float progress = 100.0 * processed_chunks / num_chunks; + fprintf(stderr, "progress %.2f%%\n", progress); + + // the return value is currently ignored + return 0; +} + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Offline/Non-streaming speaker diarization with sherpa-onnx +Usage example: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. 
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
new file mode 100644
index 000000000..170973114
--- /dev/null
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
@@ -0,0 +1,133 @@
+// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "sherpa-onnx/csrc/wave-reader.h"
+
+static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
+                                void *arg) {
+  float progress = 100.0 * processed_chunks / num_chunks;
+  fprintf(stderr, "progress %.2f%%\n", progress);
+
+  // the return value is currently ignored
+  return 0;
+}
+
+int main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+Offline/Non-streaming speaker diarization with sherpa-onnx
+
+Usage example:
+
+Step 1: Download a speaker segmentation model
+
+Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
+for a list of available models. The following is an example:
+
+  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
+  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
+  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
+
+Step 2: Download a speaker embedding extractor model
+
+Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
+for a list of available models. The following is an example:
+
+  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
+
+Step 3: Download test wave files
+
+Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
+for a list of available test wave files. The following is an example:
+
+  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
+
+Step 4: Build sherpa-onnx
+
+Step 5: Run it
+
+  ./bin/sherpa-onnx-offline-speaker-diarization \
+    --clustering.num-clusters=4 \
+    --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
+    --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
+    ./0-four-speakers-zh.wav
+
+Since we know that there are four speakers in the test wave file, we use
+--clustering.num-clusters=4 in the above example.
+
+If we don't know the number of speakers in the given wave file, we can use
+the argument --clustering.cluster-threshold. The following is an example:
+
+  ./bin/sherpa-onnx-offline-speaker-diarization \
+    --clustering.cluster-threshold=0.90 \
+    --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
+    --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
+    ./0-four-speakers-zh.wav
+
+A larger threshold leads to fewer clusters, i.e., fewer speakers;
+a smaller threshold leads to more clusters, i.e., more speakers.
+)usage";
+
+  sherpa_onnx::OfflineSpeakerDiarizationConfig config;
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  config.Register(&po);
+  po.Read(argc, argv);
+
+  std::cout << config.ToString() << "\n";
+
+  if (!config.Validate()) {
+    po.PrintUsage();
+    std::cerr << "Errors in config!\n";
+    return -1;
+  }
+
+  if (po.NumArgs() != 1) {
+    std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
+    po.PrintUsage();
+    return -1;
+  }
+
+  sherpa_onnx::OfflineSpeakerDiarization sd(config);
+
+  std::cout << "Started\n";
+  const auto begin = std::chrono::steady_clock::now();
+
+  const std::string wav_filename = po.GetArg(1);
+  int32_t sample_rate = -1;
+  bool is_ok = false;
+  const std::vector<float> samples =
+      sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
+  if (!is_ok) {
+    std::cerr << "Failed to read " << wav_filename << "\n";
+    return -1;
+  }
+
+  if (sample_rate != sd.SampleRate()) {
+    std::cerr << "Expect sample rate " << sd.SampleRate()
+              << ". Given: " << sample_rate << "\n";
+    return -1;
+  }
+
+  float duration = samples.size() / static_cast<float>(sample_rate);
+
+  auto result =
+      sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
+          .SortByStartTime();
+
+  for (const auto &r : result) {
+    std::cout << r.ToString() << "\n";
+  }
+
+  const auto end = std::chrono::steady_clock::now();
+  float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+
+  fprintf(stderr, "Duration : %.3f s\n", duration);
+  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
+  float rtf = elapsed_seconds / duration;
+  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
+          elapsed_seconds, duration, rtf);
+
+  return 0;
+}
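For completeness, the same pipeline can be driven from the C++ API directly instead of the command-line binary. The sketch below mirrors the driver above; the config field names (segmentation.pyannote.model, embedding.model, clustering.num_clusters) are assumptions inferred from the --segmentation.*, --embedding.*, and --clustering.* options, so check offline-speaker-diarization.h for the exact names and for whether Process() accepts a null progress callback.

    #include <iostream>
    #include <vector>

    #include "sherpa-onnx/csrc/offline-speaker-diarization.h"
    #include "sherpa-onnx/csrc/wave-reader.h"

    int main() {
      sherpa_onnx::OfflineSpeakerDiarizationConfig config;
      // Assumed field names; they mirror the command-line options above.
      config.segmentation.pyannote.model =
          "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";
      config.embedding.model =
          "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";
      config.clustering.num_clusters = 4;
      if (!config.Validate()) return -1;

      sherpa_onnx::OfflineSpeakerDiarization sd(config);

      int32_t sample_rate = -1;
      bool is_ok = false;
      std::vector<float> samples = sherpa_onnx::ReadWave(
          "./0-four-speakers-zh.wav", &sample_rate, &is_ok);
      if (!is_ok || sample_rate != sd.SampleRate()) return -1;

      // Pass a callback such as ProgressCallback above to report progress.
      auto result =
          sd.Process(samples.data(), samples.size(), nullptr, nullptr)
              .SortByStartTime();
      for (const auto &r : result) {
        std::cout << r.ToString() << "\n";
      }
      return 0;
    }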
Given: " << sample_rate << "\n"; + return -1; + } + + float duration = samples.size() / static_cast(sample_rate); + + auto result = + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr) + .SortByStartTime(); + + for (const auto &r : result) { + std::cout << r.ToString() << "\n"; + } + + const auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Duration : %.3f s\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 442ec1813..1ab8b68de 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -9,14 +9,15 @@ #include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/wave-writer.h" -int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) { +static int32_t AudioCallback(const float * /*samples*/, int32_t n, + float progress) { printf("sample=%d, progress=%f\n", n, progress); return 1; } int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( -Offline text-to-speech with sherpa-onnx +Offline/Non-streaming text-to-speech with sherpa-onnx Usage example: @@ -79,7 +80,7 @@ or details. sherpa_onnx::OfflineTts tts(config); const auto begin = std::chrono::steady_clock::now(); - auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback); const auto end = std::chrono::steady_clock::now(); if (audio.samples.empty()) { diff --git a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc index ea83cfaaf..faca83b98 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc @@ -19,7 +19,7 @@ The input text can contain English words. Usage: Please download the model from: -https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 ./bin/Release/sherpa-onnx-online-punctuation \ --cnn-bilstm=/path/to/model.onnx \ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index 1c99de1a0..d90b0b1e0 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { bool SpeakerEmbeddingExtractorConfig::Validate() const { if (model.empty()) { - SHERPA_ONNX_LOGE("Please provide --model"); + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model"); return false; } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist", model.c_str()); return false; }