Skip to content

Commit

Permalink
Handle NaN embeddings in speaker diarization. (#1461)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 24, 2024
1 parent b3e05f6 commit a5295aa
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cxx-api-examples/sense-voice-cxx-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"

int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OfflineRecognizerConfig config;

config.model_config.sense_voice.model =
Expand Down
2 changes: 1 addition & 1 deletion cxx-api-examples/streaming-zipformer-cxx-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"

int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OnlineRecognizerConfig config;

// please see
Expand Down
2 changes: 1 addition & 1 deletion cxx-api-examples/whisper-cxx-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"

int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OfflineRecognizerConfig config;

config.model_config.whisper.encoder =
Expand Down
5 changes: 4 additions & 1 deletion scripts/check_style_cpplint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ function is_source_code_file() {
}

function check_style() {
  # Lint a single file ($1) with cpplint. If cpplint exits non-zero, the
  # file name is handed to abort.
  # Paths that begin with "mfc-example" are skipped without being linted.
  if [[ "$1" == mfc-example* ]]; then
    return
  fi
  # Quote every expansion so that a path containing whitespace or glob
  # characters reaches cpplint (and abort) as exactly one argument.
  python3 "$cpplint_src" "$1" || abort "$1"
}

Expand Down Expand Up @@ -99,7 +102,7 @@ function do_check() {
;;
2)
echo "Check all files"
files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
;;
*)
echo "Check last commit"
Expand Down
42 changes: 40 additions & 2 deletions sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_

#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
}

auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);

// The embedding model may output NaN. valid_indexes contains indexes
// in chunk_speaker_samples_list_pair.second that don't lead to
// NaN embeddings.
std::vector<int32_t> valid_indexes;
valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());

Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);
&valid_indexes, std::move(callback), callback_arg);

if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
std::vector<Int32Pair> chunk_speaker_pair;
std::vector<std::vector<Int32Pair>> sample_indexes;

chunk_speaker_pair.reserve(valid_indexes.size());
sample_indexes.reserve(valid_indexes.size());
for (auto i : valid_indexes) {
chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
sample_indexes.push_back(
std::move(chunk_speaker_samples_list_pair.second[i]));
}

chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
}

std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
Expand Down Expand Up @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Matrix2D ComputeEmbeddings(
const float *audio, int32_t n,
const std::vector<std::vector<Int32Pair>> &sample_indexes,
std::vector<int32_t> *valid_indexes,
OfflineSpeakerDiarizationProgressCallback callback,
void *callback_arg) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t sample_rate = meta_data.sample_rate;
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());

auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };

int32_t k = 0;
int32_t cur_row_index = 0;
for (const auto &v : sample_indexes) {
auto stream = embedding_extractor_.CreateStream();
for (const auto &p : v) {
Expand All @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl

std::vector<float> embedding = embedding_extractor_.Compute(stream.get());

std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
// a valid embedding
std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
cur_row_index += 1;
valid_indexes->push_back(k);
}

k += 1;

Expand All @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
}
}

if (k != cur_row_index) {
auto seq = Eigen::seqN(0, cur_row_index);
ans = ans(seq, Eigen::all);
}

return ans;
}

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
auto variance = EX2 - EX.array().pow(2);
auto stddev = variance.array().sqrt();

m = (m.rowwise() - EX).array().rowwise() / stddev.array();
m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5);
}

private:
Expand Down

0 comments on commit a5295aa

Please sign in to comment.