Skip to content

Commit

Permalink
[binding] add chunk size interface and use non-streaming decoding by …
Browse files Browse the repository at this point in the history
…default since it's fast and accurate
  • Loading branch information
robin1001 committed Aug 27, 2023
1 parent c39f20b commit ae55630
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 3 deletions.
2 changes: 1 addition & 1 deletion runtime/binding/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ with wave.open(test_wav, 'rb') as fin:
assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())

decoder = wenet.Decoder(lang='chs')
decoder = wenet.Decoder(lang='chs', streaming=True)
# We suppose the wav is 16k, 16bits, and decode every 0.5 seconds
interval = int(0.5 * 16000) * 2
for i in range(0, len(wav), interval):
Expand Down
3 changes: 2 additions & 1 deletion runtime/binding/python/cpp/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

namespace py = pybind11;


PYBIND11_MODULE(_wenet, m) {
m.doc() = "wenet pybind11 plugin"; // optional module docstring
m.def("wenet_init", &wenet_init, py::return_value_policy::reference,
Expand All @@ -36,4 +35,6 @@ PYBIND11_MODULE(_wenet, m) {
m.def("wenet_set_language", &wenet_set_language, "set language");
m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding,
"enable continuous decoding or not");
m.def("wenet_set_chunk_size", &wenet_set_chunk_size,
"set decoding chunk size");
}
9 changes: 8 additions & 1 deletion runtime/binding/python/wenetruntime/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self,
enable_timestamp: bool = False,
context: Optional[List[str]] = None,
context_score: float = 3.0,
continuous_decoding: bool = False):
continuous_decoding: bool = False,
streaming: bool = False):
""" Init WeNet decoder
Args:
lang: language type of the model
Expand All @@ -44,6 +45,7 @@ def __init__(self,
context: context words
context_score: bonus score when the context is matched
continuous_decoding: enable continuous decoding or not
streaming: streaming mode
"""
if model_dir is None:
model_dir = Hub.get_model_by_lang(lang)
Expand All @@ -57,6 +59,8 @@ def __init__(self,
self.add_context(context)
self.set_context_score(context_score)
self.set_continuous_decoding(continuous_decoding)
chunk_size = 16 if streaming else -1
self.set_chunk_size(chunk_size)

def __del__(self):
_wenet.wenet_free(self.d)
Expand Down Expand Up @@ -90,6 +94,9 @@ def set_continuous_decoding(self, continuous_decoding: bool):
flag = 1 if continuous_decoding else 0
_wenet.wenet_set_continuous_decoding(self.d, flag)

def set_chunk_size(self, chunk_size: int):
    """ Set the decoding chunk size
    Args:
        chunk_size: decoding chunk size, -1 for non-streaming decoding
    """
    handle = self.d
    _wenet.wenet_set_chunk_size(handle, chunk_size)

def decode(self,
audio: Union[str, bytes, np.ndarray],
last: bool = True) -> str:
Expand Down
13 changes: 13 additions & 0 deletions runtime/binding/python/wenetruntime/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse

from wenetruntime.decoder import Decoder
from _wenet import wenet_set_log_level as set_log_level # noqa


def get_args():
Expand All @@ -23,14 +24,26 @@ def get_args():
default='chs',
choices=['chs', 'en'],
help='select language')
parser.add_argument('-c',
'--chunk_size',
default=-1,
type=int,
help='set decoding chunk size')
parser.add_argument('-v',
'--verbose',
default=0,
type=int,
help='set log(glog backend) level')
parser.add_argument('audio', help='input audio file')
args = parser.parse_args()
return args


def main():
    """ Command line entry: decode one audio file and print the result """
    options = get_args()
    # Apply the requested log level before the decoder is created
    set_log_level(options.verbose)
    asr = Decoder(lang=options.language)
    # Override the chunk size chosen by the Decoder constructor
    asr.set_chunk_size(options.chunk_size)
    print(asr.decode(options.audio))

Expand Down
9 changes: 9 additions & 0 deletions runtime/core/api/wenet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class Recognizer {
}
resource_->post_processor =
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
// Init decode options
decode_options_->chunk_size = chunk_size_;
// Init decoder
decoder_ = std::make_shared<wenet::AsrDecoder>(feature_pipeline_, resource_,
*decode_options_);
Expand Down Expand Up @@ -180,6 +182,7 @@ class Recognizer {
void set_context_score(float score) { context_score_ = score; }
// Selects the language by name. A null `lang` is ignored and the current
// language is kept: assigning a null `const char*` to the std::string member
// would be undefined behavior.
void set_language(const char* lang) {
  if (lang != nullptr) {
    language_ = lang;
  }
}
void set_continuous_decoding(bool flag) { continuous_decoding_ = flag; }
void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }

private:
// NOTE(Binbin Zhang): All use shared_ptr for clone in the future
Expand All @@ -197,6 +200,7 @@ class Recognizer {
float context_score_;
std::string language_ = "chs";
bool continuous_decoding_ = false;
int chunk_size_ = 16;
};

void* wenet_init(const char* model_dir) {
Expand Down Expand Up @@ -255,3 +259,8 @@ void wenet_set_continuous_decoding(void* decoder, int flag) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_continuous_decoding(flag > 0);
}

// C API entry point: forwards `chunk_size` to the Recognizer behind the opaque
// `decoder` handle. Per the header, -1 selects non-streaming decoding.
void wenet_set_chunk_size(void* decoder, int chunk_size) {
  // The handle is a type-erased Recognizer pointer; cast it back before use.
  auto* const target = static_cast<Recognizer*>(decoder);
  target->set_chunk_size(chunk_size);
}
4 changes: 4 additions & 0 deletions runtime/core/api/wenet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ void wenet_set_log_level(int level);
*/
void wenet_set_continuous_decoding(void* decoder, int flag);

/** Set chunk size for decoding, -1 for non-streaming decoding
 * @param decoder opaque decoder handle
 * @param chunk_size decoding chunk size; pass -1 for non-streaming decoding
 */
void wenet_set_chunk_size(void* decoder, int chunk_size);

#ifdef __cplusplus
}
#endif
Expand Down

0 comments on commit ae55630

Please sign in to comment.