Skip to content

Commit

Permalink
Support pre-trained CTC models from NeMo (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Mar 10, 2023
1 parent 9f30bfe commit 32da448
Show file tree
Hide file tree
Showing 39 changed files with 445 additions and 133 deletions.
26 changes: 26 additions & 0 deletions .github/scripts/run-offline-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,29 @@ log "Decoding with HLG"
rm -rf $repo
log "End of testing ${repo_url}"
log "=========================================================================="

# ---------------------------------------------------------------------------
# Test a pre-trained NeMo CTC model (Citrinet-512, English).
#
# All expansions of $repo / $repo_url are double-quoted so that the values
# are never subject to word splitting or globbing (this matters most for
# the final `rm -rf`).
# ---------------------------------------------------------------------------
repo_url=https://huggingface.co/csukuangfj/sherpa-nemo-ctc-en-citrinet-512
log "Start testing ${repo_url}"
repo=$(basename "$repo_url")
log "Download pretrained model and test-data from $repo_url"

# Clone with GIT_LFS_SKIP_SMUDGE=1 and then pull only model.pt through
# git-lfs, so no other LFS-tracked file is requested.
GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
pushd "$repo"
git lfs pull --include "model.pt"
popd

log "Decoding with H"

# --nemo-normalize=per_feature selects the NeMo feature normalization mode
# used for this model.
./build/bin/sherpa-offline \
  --nn-model="$repo/model.pt" \
  --tokens="$repo/tokens.txt" \
  --use-gpu=false \
  --modified=false \
  --nemo-normalize=per_feature \
  "$repo/test_wavs/0.wav" \
  "$repo/test_wavs/1.wav" \
  "$repo/test_wavs/2.wav"

# Remove the cloned repository before the next test section runs.
rm -rf "$repo"
log "End of testing ${repo_url}"
log "=========================================================================="
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ errs-*
rtf-*
*.wav
Testing
run-offline-ctc*.sh
run-offline-asr*.sh
3 changes: 2 additions & 1 deletion cmake/asio.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function(download_asio)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(asio_URL "file://${f}")
set(asio_URL "${f}")
file(TO_CMAKE_PATH "${asio_URL}" asio_URL)
set(asio_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/googletest.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function(download_googltest)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(googletest_URL "file://${f}")
set(googletest_URL "${f}")
file(TO_CMAKE_PATH "${googletest_URL}" googletest_URL)
set(googletest_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/json.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function(download_json)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(json_URL "file://${f}")
set(json_URL "${f}")
file(TO_CMAKE_PATH "${json_URL}" json_URL)
set(json_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/kaldi_native_io.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function(download_kaldi_native_io)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldi_native_io_URL "file://${f}")
set(kaldi_native_io_URL "${f}")
file(TO_CMAKE_PATH "${kaldi_native_io_URL}" kaldi_native_io_URL)
set(kaldi_native_io_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/portaudio.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ function(download_portaudio)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(portaudio_URL "file://${f}")
set(portaudio_URL "${f}")
file(TO_CMAKE_PATH "${portaudio_URL}" portaudio_URL)
set(portaudio_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/pybind11.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function(download_pybind11)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(pybind11_URL "file://${f}")
set(pybind11_URL "${f}")
file(TO_CMAKE_PATH "${pybind11_URL}" pybind11_URL)
set(pybind11_URL2)
break()
endif()
Expand Down
3 changes: 2 additions & 1 deletion cmake/websocketpp.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ function(download_websocketpp)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(websocketpp_URL "file://${f}")
set(websocketpp_URL "${f}")
file(TO_CMAKE_PATH "${websocketpp_URL}" websocketpp_URL)
set(websocketpp_URL2)
break()
endif()
Expand Down
31 changes: 30 additions & 1 deletion sherpa/bin/offline_ctc_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@
./wav2vec2.0-torchaudio/test_wavs/1089-134686-0001.wav \
./wav2vec2.0-torchaudio/test_wavs/1221-135766-0001.wav \
./wav2vec2.0-torchaudio/test_wavs/1221-135766-0002.wav
(4) Use NeMo CTC models
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-nemo-ctc-en-citrinet-512
cd sherpa-nemo-ctc-en-citrinet-512
git lfs pull --include "model.pt"
cd /path/to/sherpa
./sherpa/bin/offline_ctc_asr.py \
--nn-model ./sherpa-nemo-ctc-en-citrinet-512/model.pt \
--tokens ./sherpa-nemo-ctc-en-citrinet-512/tokens.txt \
--use-gpu false \
--nemo-normalize per_feature \
./sherpa-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-nemo-ctc-en-citrinet-512/test_wavs/2.wav
"""
import argparse
import logging
Expand Down Expand Up @@ -121,7 +138,18 @@ def get_parser():
default=True,
help="""If your model was trained using features computed
from samples in the range `[-32768, 32767]`, then please set
this flag to False.
this flag to False. For instance, if you use models from WeNet,
please set it to False.
""",
)

parser.add_argument(
"--nemo-normalize",
type=str,
default="",
help="""Used only for models from NeMo.
Leave it to empty if the preprocessor of the model does not use
normalization. Current supported value is "per_feature".
""",
)

Expand Down Expand Up @@ -283,6 +311,7 @@ def create_recognizer(args):
feat_config.fbank_opts.frame_opts.dither = 0

feat_config.normalize_samples = args.normalize_samples
feat_config.nemo_normalize = args.nemo_normalize

ctc_decoder_config = sherpa.OfflineCtcDecoderConfig(
hlg=args.HLG if args.HLG else "",
Expand Down
1 change: 1 addition & 0 deletions sherpa/cpp_api/bin/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ int main(int argc, char *argv[]) {
<< "The model was trained using training data with sample rate 16000. "
<< "We don't support resample yet";

SHERPA_LOG(INFO) << config.ToString();
sherpa::OfflineRecognizer recognizer(config);

if (use_wav_scp) {
Expand Down
10 changes: 9 additions & 1 deletion sherpa/cpp_api/feature-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,21 @@ void FeatureConfig::Register(ParseOptions *po) {
"true to use samples in the range [-1, 1]. "
"false to use samples in the range [-32768, 32767]. "
"Note: kaldi uses un-normalized samples.");

po->Register(
"nemo-normalize", &nemo_normalize,
"See "
"https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/"
"preprocessing/features.py#L59"
"Current supported value: per_feature or leave it to empty (unset)");
}

// Returns a one-line, human-readable description of this config, intended
// for logging/debugging. The layout is:
//
//   FeatureConfig(fbank_opts=<...>, normalize_samples=<True|False>,
//   nemo_normalize="<value>")
//
// (emitted on a single line).
std::string FeatureConfig::ToString() const {
  std::ostringstream os;
  os << "FeatureConfig(";
  os << "fbank_opts=" << fbank_opts.ToString() << ", ";
  // Every field is written exactly once. Only the last field carries the
  // closing parenthesis, so normalize_samples must end with ", ".
  os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", ";
  os << "nemo_normalize=\"" << nemo_normalize << "\")";
  return os.str();
}

Expand Down
19 changes: 19 additions & 0 deletions sherpa/cpp_api/feature-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ struct FeatureConfig {
// If false, we scale the input samples by 32767 inside sherpa
bool normalize_samples = true;

// For Wav2Vec 2.0, we set it to true so that it returns audio samples
// directly.
//
// The user does not need to set it. We set it internally when we
// load a Wav2Vec 2.0 model.
bool return_waveform = false;

// For models from NeMo
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std::string nemo_normalize;

void Register(ParseOptions *po);

/** A string representation for debugging purpose. */
Expand Down
53 changes: 35 additions & 18 deletions sherpa/cpp_api/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "sherpa/csrc/offline-ctc-decoder.h"
#include "sherpa/csrc/offline-ctc-model.h"
#include "sherpa/csrc/offline-ctc-one-best-decoder.h"
#include "sherpa/csrc/offline-nemo-enc-dec-ctc-model-bpe.h"
#include "sherpa/csrc/offline-wav2vec2-ctc-model.h"
#include "sherpa/csrc/offline-wenet-conformer-ctc-model.h"
#include "sherpa/csrc/symbol-table.h"
Expand Down Expand Up @@ -80,26 +81,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
// https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/model.py#L11
model_ =
std::make_unique<OfflineWav2Vec2CtcModel>(config.nn_model, device_);
return_waveform_ = true;
config_.feat_config.return_waveform = true;
symbol_table_.Replace(symbol_table_["|"], " ", "|");
// See Section 4.2 of
// https://arxiv.org/pdf/2006.11477.pdf
config_.feat_config.fbank_opts.frame_opts.frame_shift_ms = 20;
SHERPA_LOG(WARNING) << "Set frame_shift_ms to 20 for wav2vec 2.0";
} else if (class_name == "EncDecCTCModelBPE") {
// This one is from NeMo
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py#L34
//
model_ = std::make_unique<OfflineNeMoEncDecCTCModelBPE>(config.nn_model,
device_);
} else {
std::string s =
"Support only models from icefall, wenet and torchaudio\n"
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"conformer_ctc/conformer.py#L27"
"\n"
"https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/"
"asr_model.py#L42"
"\n"
"https://github.com/pytorch/audio/blob/main/torchaudio/models/"
"wav2vec2/model.py#L11"
"\n";

TORCH_CHECK(false, s);
std::ostringstream os;
os << "Support only models from icefall, wenet, torchaudio, and NeMo\n"
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/"
"ASR/"
"conformer_ctc/conformer.py#L27"
"\n"
"https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/"
"asr_model.py#L42"
"\n"
"https://github.com/pytorch/audio/blob/main/torchaudio/models/"
"wav2vec2/model.py#L11"
"\n"
"https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/"
"models/ctc_bpe_models.py#L34"
<< "\n"
<< "Given: " << class_name << "\n";

TORCH_CHECK(false, os.str());
}

WarmUp();
Expand All @@ -109,8 +122,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}

// Creates a new OfflineStream from a pointer to fbank_ and the recognizer's
// feature configuration.
//
// The complete feat_config is forwarded (rather than individual flags) so
// that the stream receives normalize_samples, return_waveform and
// nemo_normalize together. A single return statement is kept; an earlier
// three-argument return would make this one unreachable and would never
// hand nemo_normalize to the stream.
std::unique_ptr<OfflineStream> CreateStream() override {
  return std::make_unique<OfflineStream>(&fbank_, config_.feat_config);
}

void DecodeStreams(OfflineStream **ss, int32_t n) override {
Expand All @@ -125,8 +137,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}

// If return_waveform is false, features_vec contains 2-D tensors of shape
// (num_frames, feature_dim). In this case, we should use the padding value
// -23.
// (num_frames, feature_dim). In this case, we should use the padding
// value -23.
//
// If return_waveform is true, features_vec contains 1-D tensors of shape
// (num_samples,). In this case, we use 0 as the padding value.
Expand All @@ -139,6 +151,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
torch::IValue ivalue = model_->Forward(features, features_length);
torch::Tensor log_prob = model_->GetLogSoftmaxOut(ivalue);
torch::Tensor log_prob_len = model_->GetLogSoftmaxOutLength(ivalue);
if (!log_prob_len.defined()) {
log_prob_len =
torch::floor_divide(features_length, model_->SubsamplingFactor());
log_prob_len = log_prob_len.to(log_prob.device());
}

auto results =
decoder_->Decode(log_prob, log_prob_len, model_->SubsamplingFactor());
Expand Down
3 changes: 1 addition & 2 deletions sherpa/cpp_api/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {

// Creates a new OfflineStream from a pointer to fbank_ and the recognizer's
// feature configuration.
//
// No local return_waveform flag is needed: the setting is taken from
// config_.feat_config.
// NOTE(review): this assumes feat_config.return_waveform keeps its default
// value for transducer models -- confirm nothing in this class sets it.
std::unique_ptr<OfflineStream> CreateStream() override {
  return std::make_unique<OfflineStream>(&fbank_, config_.feat_config);
}

void DecodeStreams(OfflineStream **ss, int32_t n) override {
Expand Down
Loading

0 comments on commit 32da448

Please sign in to comment.