Skip to content

Commit

Permalink
support using xnnpack as execution provider (#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Feb 28, 2024
1 parent 87a7030 commit 0cb6d1b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime-linux-riscv64-static.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-static_lib-1.18.0.zip")
set(onnxruntime_URL2 "https://hub.nuaa.cf/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-static_lib-1.18.0.zip")
set(onnxruntime_HASH "SHA256=6791d695d17118db6815364c975a9d7ea9a8909754516ed1b089fe015c20912e")
set(onnxruntime_HASH "SHA256=77ecc51d8caf0953755db6edcdec2fc03bce3f6d379bedd635be50bb95f88da5")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-linux-riscv64.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-1.18.0.zip")
set(onnxruntime_URL2 "https://hub.nuaa.cf/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-1.18.0.zip")
set(onnxruntime_HASH "SHA256=87ef36dbba28ee332069e7e511dcb409913bdeeed231b45172fe200d71c690a2")
set(onnxruntime_HASH "SHA256=81a11b54d1d71f4b3161b00cba8576a07594abd218aa5c0d82382960ada06092")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Provider StringToProvider(std::string s) {
return Provider::kCUDA;
} else if (s == "coreml") {
return Provider::kCoreML;
} else if (s == "xnnpack") {
return Provider::kXnnpack;
} else {
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
return Provider::kCPU;
Expand Down
7 changes: 4 additions & 3 deletions sherpa-onnx/csrc/provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ namespace sherpa_onnx {
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
};

/**
Expand Down
23 changes: 19 additions & 4 deletions sherpa-onnx/csrc/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
sess_opts.SetIntraOpNumThreads(num_threads);
sess_opts.SetInterOpNumThreads(num_threads);

std::vector<std::string> available_providers = Ort::GetAvailableProviders();
std::ostringstream os;
for (const auto &ep : available_providers) {
os << ep << ", ";
}

// Other possible options
// sess_opts.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
// sess_opts.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE);
Expand All @@ -33,9 +39,17 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
switch (p) {
case Provider::kCPU:
break; // nothing to do for the CPU provider
case Provider::kXnnpack: {
if (std::find(available_providers.begin(), available_providers.end(),
"XnnpackExecutionProvider") != available_providers.end()) {
sess_opts.AppendExecutionProvider("XNNPACK");
} else {
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
os.str().c_str());
}
break;
}
case Provider::kCUDA: {
std::vector<std::string> available_providers =
Ort::GetAvailableProviders();
if (std::find(available_providers.begin(), available_providers.end(),
"CUDAExecutionProvider") != available_providers.end()) {
// The CUDA provider is available, proceed with setting the options
Expand All @@ -47,8 +61,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
sess_opts.AppendExecutionProvider_CUDA(options);
} else {
SHERPA_ONNX_LOGE(
"Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to "
"cpu!");
"Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Available "
"providers: %s. Fallback to cpu!",
os.str().c_str());
}
break;
}
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ This program shows how to use VAD in sherpa-onnx.
./bin/sherpa-onnx-vad-microphone \
--silero-vad-model=/path/to/silero_vad.onnx \
--provider=cpu \
--num-threads=1
--vad-provider=cpu \
--vad-num-threads=1
Please download silero_vad.onnx from
https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
Expand Down

0 comments on commit 0cb6d1b

Please sign in to comment.