Skip to content

Commit

Permalink
Add tail_paddings to Whisper C API. (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored May 17, 2024
1 parent 65635b0 commit 8af2af8
Show file tree
Hide file tree
Showing 13 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.9.24")
set(SHERPA_ONNX_VERSION "1.9.25")

# Disable warning about
#
Expand Down
1 change: 1 addition & 0 deletions nodejs-examples/test-offline-nemo-ctc.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
decoder: '',
language: '',
task: '',
tailPaddings: -1,
},
tdnn: {
model: '',
Expand Down
1 change: 1 addition & 0 deletions nodejs-examples/test-offline-paraformer.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
decoder: '',
language: '',
task: '',
tailPaddings: -1,
},
tdnn: {
model: '',
Expand Down
1 change: 1 addition & 0 deletions nodejs-examples/test-offline-transducer.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function createOfflineRecognizer() {
decoder: '',
language: '',
task: '',
tailPaddings: -1,
},
tdnn: {
model: '',
Expand Down
1 change: 1 addition & 0 deletions nodejs-examples/test-offline-whisper.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
language: '',
task: 'transcribe',
tailPaddings: -1,
},
tdnn: {
model: '',
Expand Down
3 changes: 3 additions & 0 deletions scripts/dotnet/offline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ public OfflineWhisperModelConfig()
Decoder = "";
Language = "";
Task = "transcribe";
TailPaddings = -1;
}
[MarshalAs(UnmanagedType.LPStr)]
public string Encoder;
Expand All @@ -313,6 +314,8 @@ public OfflineWhisperModelConfig()

[MarshalAs(UnmanagedType.LPStr)]
public string Task;

public int TailPaddings;
}

[StructLayout(LayoutKind.Sequential)]
Expand Down
11 changes: 7 additions & 4 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct {
}

type OfflineWhisperModelConfig struct {
Encoder string
Decoder string
Language string
Task string
Encoder string
Decoder string
Language string
Task string
TailPaddings int
}

type OfflineTdnnModelConfig struct {
Expand Down Expand Up @@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
defer C.free(unsafe.Pointer(c.model_config.whisper.task))

c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings)

c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
defer C.free(unsafe.Pointer(c.model_config.tdnn.model))

Expand Down
3 changes: 2 additions & 1 deletion scripts/node-addon-api/src/non-streaming-asr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig(
SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
SHERPA_ONNX_ASSIGN_ATTR_STR(language, language);
SHERPA_ONNX_ASSIGN_ATTR_STR(task, languagek);
SHERPA_ONNX_ASSIGN_ATTR_STR(task, task);
SHERPA_ONNX_ASSIGN_ATTR_INT32(tail_paddings, tailPaddings);

return c;
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.whisper.task = "transcribe";
}

recognizer_config.model_config.whisper.tail_paddings =
SHERPA_ONNX_OR(config->model_config.whisper.tail_paddings, -1);

recognizer_config.model_config.tdnn.model =
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");

Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
const char *decoder;
const char *language;
const char *task;
int32_t tail_paddings;
} SherpaOnnxOfflineWhisperModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
Expand Down
6 changes: 4 additions & 2 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig(
encoder: String = "",
decoder: String = "",
language: String = "",
task: String = "transcribe"
task: String = "transcribe",
tailPaddings: Int = -1
) -> SherpaOnnxOfflineWhisperModelConfig {
return SherpaOnnxOfflineWhisperModelConfig(
encoder: toCPointer(encoder),
decoder: toCPointer(decoder),
language: toCPointer(language),
task: toCPointer(task)
task: toCPointer(task),
tail_paddings: Int32(tailPaddings)
)
}

Expand Down
2 changes: 2 additions & 0 deletions wasm/asr/sherpa-onnx-asr.js
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) {
Module.setValue(ptr + 12, buffer + offset, 'i8*');
offset += taskLen;

Module.setValue(ptr + 16, config.tailPaddings || -1, 'i32');

return {
buffer: buffer, ptr: ptr, len: len,
}
Expand Down
3 changes: 2 additions & 1 deletion wasm/nodejs/sherpa-onnx-wasm-nodejs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, "");
static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, "");

static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, "");
static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, "");
static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 5 * 4, "");
static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, "");
static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, "");

Expand Down Expand Up @@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
fprintf(stdout, "decoder: %s\n", whisper->decoder);
fprintf(stdout, "language: %s\n", whisper->language);
fprintf(stdout, "task: %s\n", whisper->task);
fprintf(stdout, "tail_paddings: %d\n", whisper->tail_paddings);

fprintf(stdout, "----------offline tdnn model config----------\n");
fprintf(stdout, "model: %s\n", tdnn->model);
Expand Down

0 comments on commit 8af2af8

Please sign in to comment.