diff --git a/.cargo/config.toml b/.cargo/config.toml
index 64fcc6aea..f015cd17a 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -4,10 +4,6 @@ xtask = "run -p xtask --"
[env]
CARGO_WORKSPACE_DIR = { value = "", relative = true }
-# Workaround to prevent test failures on Windows environments
-# https://github.com/VOICEVOX/onnxruntime-rs/issues/3#issuecomment-1207381367
-ORT_OUT_DIR = { value = "target/debug/deps", relative = true }
-
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"
diff --git a/.github/workflows/build_and_deploy.yml b/.github/workflows/build_and_deploy.yml
index 83d895c77..01ddfc434 100644
--- a/.github/workflows/build_and_deploy.yml
+++ b/.github/workflows/build_and_deploy.yml
@@ -58,7 +58,6 @@ jobs:
"target": "x86_64-pc-windows-msvc",
"artifact_name": "windows-x64-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -67,16 +66,14 @@ jobs:
"target": "x86_64-pc-windows-msvc",
"artifact_name": "windows-x64-directml",
"whl_local_version": "directml",
- "use_cuda": false,
"can_skip_in_simple_test": false
},
{
"os": "windows-2019",
- "features": "",
+ "features": "cuda",
"target": "x86_64-pc-windows-msvc",
"artifact_name": "windows-x64-cuda",
"whl_local_version": "cuda",
- "use_cuda": true,
"can_skip_in_simple_test": true
},
{
@@ -85,7 +82,6 @@ jobs:
"target": "i686-pc-windows-msvc",
"artifact_name": "windows-x86-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -94,16 +90,14 @@ jobs:
"target": "x86_64-unknown-linux-gnu",
"artifact_name": "linux-x64-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
"os": "ubuntu-20.04",
- "features": "",
+ "features": "cuda",
"target": "x86_64-unknown-linux-gnu",
"artifact_name": "linux-x64-gpu",
"whl_local_version": "cuda",
- "use_cuda": true,
"can_skip_in_simple_test": false
},
{
@@ -112,7 +106,6 @@ jobs:
"target": "aarch64-unknown-linux-gnu",
"artifact_name": "linux-arm64-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -120,7 +113,6 @@ jobs:
"features": "",
"target": "aarch64-linux-android",
"artifact_name": "android-arm64-cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -128,7 +120,6 @@ jobs:
"features": "",
"target": "x86_64-linux-android",
"artifact_name": "android-x86_64-cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -137,7 +128,6 @@ jobs:
"target": "aarch64-apple-darwin",
"artifact_name": "osx-arm64-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": false
},
{
@@ -146,7 +136,6 @@ jobs:
"target": "x86_64-apple-darwin",
"artifact_name": "osx-x64-cpu",
"whl_local_version": "cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -154,7 +143,6 @@ jobs:
"features": "",
"target": "aarch64-apple-ios",
"artifact_name": "ios-arm64-cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -162,7 +150,6 @@ jobs:
"features": "",
"target": "aarch64-apple-ios-sim",
"artifact_name": "ios-arm64-cpu-sim",
- "use_cuda": false,
"can_skip_in_simple_test": true
},
{
@@ -170,7 +157,6 @@ jobs:
"features": "",
"target": "x86_64-apple-ios",
"artifact_name": "ios-x64-cpu",
- "use_cuda": false,
"can_skip_in_simple_test": true
}
]'
@@ -268,7 +254,6 @@ jobs:
fi
env:
RUSTFLAGS: -C panic=abort
- ORT_USE_CUDA: ${{ matrix.use_cuda }}
- name: build voicevox_core_python_api
if: matrix.whl_local_version
id: build-voicevox-core-python-api
@@ -286,8 +271,6 @@ jobs:
build > /dev/null 2>&1
fi
echo "whl=$(find ./target/wheels -type f)" >> "$GITHUB_OUTPUT"
- env:
- ORT_USE_CUDA: ${{ matrix.use_cuda }}
- name: build voicevox_core_java_api
if: contains(matrix.target, 'android')
run: |
@@ -305,7 +288,7 @@ jobs:
cp -v crates/voicevox_core_c_api/include/voicevox_core.h "artifact/${{ env.ASSET_NAME }}"
cp -v target/${{ matrix.target }}/release/*voicevox_core.{dll,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true
cp -v target/${{ matrix.target }}/release/voicevox_core.dll.lib "artifact/${{ env.ASSET_NAME }}/voicevox_core.lib" || true
- cp -v -n target/${{ matrix.target }}/release/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/*.{dll,so.*,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true
+ cp -v -n target/${{ matrix.target }}/release/{,lib}onnxruntime*.{dll,so.*,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true
# For libonnxruntime we use the versioned shared library, so remove the unversioned ones
rm -f artifact/${{ env.ASSET_NAME }}/libonnxruntime.{so,dylib}
cp -v README.md "artifact/${{ env.ASSET_NAME }}/README.txt"
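
Note: with the `use_cuda` matrix key and the `ORT_USE_CUDA` build-script variable gone, GPU support is now selected purely through Cargo features (`--features cuda`). Inside the crate this pairs with `cfg!(feature = ...)`, which evaluates to a plain compile-time boolean, so every branch still type-checks on every target. A minimal sketch of that pattern, assuming the `cuda`/`directml` features from this diff are declared in Cargo.toml:

```rust
// `cfg!` expands to `true`/`false` at compile time; branches that are not
// selected still compile, and the optimizer drops them.
fn gpu_provider(use_gpu: bool) -> &'static str {
    if use_gpu && cfg!(feature = "directml") {
        "DirectMLExecutionProvider"
    } else if use_gpu && cfg!(feature = "cuda") {
        "CUDAExecutionProvider"
    } else {
        "CPUExecutionProvider"
    }
}

fn main() {
    // Prints "CPUExecutionProvider" unless built with one of the GPU features.
    println!("{}", gpu_provider(true));
}
```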
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 0496eebdb..d0639d045 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -72,8 +72,8 @@ jobs:
with:
python-version: "3.8"
- uses: Swatinem/rust-cache@v2
- - run: cargo clippy -vv --all-features --features onnxruntime/disable-sys-build-script --tests -- -D clippy::all -D warnings --no-deps
- - run: cargo clippy -vv --all-features --features onnxruntime/disable-sys-build-script -- -D clippy::all -D warnings --no-deps
+ - run: cargo clippy -vv --all-features --tests -- -D clippy::all -D warnings --no-deps
+ - run: cargo clippy -vv --all-features -- -D clippy::all -D warnings --no-deps
- run: cargo fmt -- --check
rust-unit-test:
@@ -199,8 +199,8 @@ jobs:
mkdir -p example/cpp/unix/voicevox_core/
cp -v crates/voicevox_core_c_api/include/voicevox_core.h example/cpp/unix/voicevox_core/
cp -v target/debug/libvoicevox_core.{so,dylib} example/cpp/unix/voicevox_core/ || true
- cp -v target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.so.* example/cpp/unix/voicevox_core/ || true
- cp -v target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.*.dylib example/cpp/unix/voicevox_core/ || true
+ cp -v target/debug/libonnxruntime.so.* example/cpp/unix/voicevox_core/ || true
+ cp -v target/debug/libonnxruntime.*.dylib example/cpp/unix/voicevox_core/ || true
- if: startsWith(matrix.os, 'mac')
uses: jwlawson/actions-setup-cmake@v1.13
@@ -281,9 +281,9 @@ jobs:
- run: poetry run maturin develop --locked
- name: Copy required DLLs and run pytest
run: |
- cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/onnxruntime.dll . || true
- cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.so.* . || true
- cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.*.dylib . || true
+ cp -v ../../target/debug/onnxruntime.dll . || true
+ cp -v ../../target/debug/libonnxruntime.so.* . || true
+ cp -v ../../target/debug/libonnxruntime.*.dylib . || true
poetry run pytest
- name: Run the example
diff --git a/Cargo.lock b/Cargo.lock
index 4015b7e36..0f9a78469 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -426,12 +426,6 @@ dependencies = [
"winapi",
]
-[[package]]
-name = "chunked_transfer"
-version = "1.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e"
-
[[package]]
name = "cipher"
version = "0.3.0"
@@ -684,6 +678,12 @@ dependencies = [
"cfg-if",
]
+[[package]]
+name = "crunchy"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+
[[package]]
name = "crypto-common"
version = "0.1.6"
@@ -1278,6 +1278,16 @@ dependencies = [
"tracing",
]
+[[package]]
+name = "half"
+version = "2.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872"
+dependencies = [
+ "cfg-if",
+ "crunchy",
+]
+
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -1410,7 +1420,7 @@ checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c"
dependencies = [
"http",
"hyper",
- "rustls",
+ "rustls 0.20.6",
"tokio",
"tokio-rustls",
]
@@ -1993,30 +2003,6 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
-[[package]]
-name = "onnxruntime"
-version = "0.1.0"
-source = "git+https://github.com/VOICEVOX/onnxruntime-rs.git?rev=ebb9dcb9b26ee681889b52b6db3b4f642b04a250#ebb9dcb9b26ee681889b52b6db3b4f642b04a250"
-dependencies = [
- "lazy_static",
- "ndarray",
- "onnxruntime-sys",
- "thiserror",
- "tracing",
-]
-
-[[package]]
-name = "onnxruntime-sys"
-version = "0.0.25"
-source = "git+https://github.com/VOICEVOX/onnxruntime-rs.git?rev=ebb9dcb9b26ee681889b52b6db3b4f642b04a250#ebb9dcb9b26ee681889b52b6db3b4f642b04a250"
-dependencies = [
- "flate2",
- "once_cell",
- "tar",
- "ureq",
- "zip",
-]
-
[[package]]
name = "opaque-debug"
version = "0.3.0"
@@ -2598,7 +2584,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
- "rustls",
+ "rustls 0.20.6",
"rustls-pemfile",
"serde",
"serde_json",
@@ -2611,7 +2597,7 @@ dependencies = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
- "webpki-roots",
+ "webpki-roots 0.22.5",
"winreg",
]
@@ -2728,6 +2714,18 @@ dependencies = [
"webpki",
]
+[[package]]
+name = "rustls"
+version = "0.21.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8"
+dependencies = [
+ "log",
+ "ring",
+ "rustls-webpki",
+ "sct",
+]
+
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
@@ -2737,6 +2735,16 @@ dependencies = [
"base64 0.21.0",
]
+[[package]]
+name = "rustls-webpki"
+version = "0.101.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe"
+dependencies = [
+ "ring",
+ "untrusted",
+]
+
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3262,7 +3270,7 @@ version = "0.23.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
dependencies = [
- "rustls",
+ "rustls 0.20.6",
"tokio",
"webpki",
]
@@ -3480,19 +3488,17 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "ureq"
-version = "2.5.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b97acb4c28a254fd7a4aeec976c46a7fa404eac4d7c134b30c75144846d7cb8f"
+checksum = "f5ccd538d4a604753ebc2f17cd9946e89b77bf87f6a8e2309667c6f2e87855e3"
dependencies = [
- "base64 0.13.0",
- "chunked_transfer",
- "flate2",
+ "base64 0.21.0",
"log",
"once_cell",
- "rustls",
+ "rustls 0.21.7",
+ "rustls-webpki",
"url",
- "webpki",
- "webpki-roots",
+ "webpki-roots 0.25.4",
]
[[package]]
@@ -3535,6 +3541,31 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+[[package]]
+name = "voicevox-ort"
+version = "2.0.0-rc.2"
+source = "git+https://github.com/VOICEVOX/ort.git?rev=a2d6ae22327869e896bf4c16828734d09516d2d9#a2d6ae22327869e896bf4c16828734d09516d2d9"
+dependencies = [
+ "half",
+ "js-sys",
+ "ndarray",
+ "thiserror",
+ "tracing",
+ "voicevox-ort-sys",
+ "web-sys",
+]
+
+[[package]]
+name = "voicevox-ort-sys"
+version = "2.0.0-rc.2"
+source = "git+https://github.com/VOICEVOX/ort.git?rev=a2d6ae22327869e896bf4c16828734d09516d2d9#a2d6ae22327869e896bf4c16828734d09516d2d9"
+dependencies = [
+ "flate2",
+ "sha2",
+ "tar",
+ "ureq",
+]
+
[[package]]
name = "voicevox_core"
version = "0.0.0"
@@ -3559,7 +3590,6 @@ dependencies = [
"nanoid",
"ndarray",
"once_cell",
- "onnxruntime",
"open_jtalk",
"ouroboros",
"pretty_assertions",
@@ -3578,6 +3608,7 @@ dependencies = [
"tokio",
"tracing",
"uuid",
+ "voicevox-ort",
"voicevox_core_macros",
"windows",
"zip",
@@ -3797,6 +3828,12 @@ dependencies = [
"webpki",
]
+[[package]]
+name = "webpki-roots"
+version = "0.25.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1"
+
[[package]]
name = "which"
version = "4.3.0"
diff --git a/Cargo.toml b/Cargo.toml
index 69d924a16..7102127d5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -86,9 +86,9 @@ voicevox_core = { path = "crates/voicevox_core" }
windows = "0.43.0"
zip = "0.6.3"
-[workspace.dependencies.onnxruntime]
-git = "https://github.com/VOICEVOX/onnxruntime-rs.git"
-rev = "ebb9dcb9b26ee681889b52b6db3b4f642b04a250"
+[workspace.dependencies.voicevox-ort]
+git = "https://github.com/VOICEVOX/ort.git"
+rev = "a2d6ae22327869e896bf4c16828734d09516d2d9"
[workspace.dependencies.open_jtalk]
git = "https://github.com/VOICEVOX/open_jtalk-rs.git"
diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml
index 5c8dd440d..731862356 100644
--- a/crates/voicevox_core/Cargo.toml
+++ b/crates/voicevox_core/Cargo.toml
@@ -6,7 +6,8 @@ publish.workspace = true
[features]
default = []
-directml = ["onnxruntime/directml"]
+cuda = ["voicevox-ort/cuda"]
+directml = ["voicevox-ort/directml"]
[dependencies]
anyhow.workspace = true
@@ -27,7 +28,6 @@ jlabel.workspace = true
nanoid.workspace = true
ndarray.workspace = true
once_cell.workspace = true
-onnxruntime.workspace = true
open_jtalk.workspace = true
ouroboros.workspace = true
rayon.workspace = true
@@ -43,6 +43,7 @@ tokio = { workspace = true, features = ["rt"] } # FIXME: feature-gate this
tracing.workspace = true
uuid = { workspace = true, features = ["v4", "serde"] }
voicevox_core_macros = { path = "../voicevox_core_macros" }
+voicevox-ort = { workspace = true, features = ["ndarray", "download-binaries"] }
zip.workspace = true
[dev-dependencies]
diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs
index fc8954e7d..c2cad1d7d 100644
--- a/crates/voicevox_core/src/infer.rs
+++ b/crates/voicevox_core/src/infer.rs
@@ -79,16 +79,20 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
pub(crate) trait InferenceInputSignature: Send + 'static {
type Signature: InferenceSignature;
const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
- fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
+ fn make_run_context<R: InferenceRuntime>(
+ self,
+ sess: &mut R::Session,
+ ) -> anyhow::Result<R::RunContext<'_>>;
}
pub(crate) trait InputScalar: Sized {
const KIND: InputScalarKind;
+ // TODO: it may be possible to take an `ArrayView` here instead of an `Array`
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
- );
+ ) -> anyhow::Result<()>;
}
#[duplicate_item(
@@ -102,8 +106,8 @@ impl InputScalar for T {
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
- ) {
- ctx.push(tensor);
+ ) -> anyhow::Result<()> {
+ ctx.push(tensor)
}
}
@@ -117,8 +121,8 @@ pub(crate) enum InputScalarKind {
}
pub(crate) trait PushInputTensor {
- fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>);
- fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>);
+ fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>) -> anyhow::Result<()>;
+ fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>) -> anyhow::Result<()>;
}
/// The output signature of an inference operation.
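
Note: the signature changes above make input pushes fallible end to end. A minimal standalone sketch of the new contract, with an illustrative `CountingCtx` backend that is not part of the crate:

```rust
use anyhow::bail;
use ndarray::{Array, Dimension};

// Illustrative analogue of the reworked `PushInputTensor`: every push now
// returns `anyhow::Result<()>`, because converting an ndarray into a runtime
// value can fail.
trait PushInputTensorLike {
    fn push_float32(
        &mut self,
        tensor: Array<f32, impl Dimension + 'static>,
    ) -> anyhow::Result<()>;
}

struct CountingCtx {
    pushed: usize,
    capacity: usize, // hypothetical limit, just to exercise the error path
}

impl PushInputTensorLike for CountingCtx {
    fn push_float32(
        &mut self,
        tensor: Array<f32, impl Dimension + 'static>,
    ) -> anyhow::Result<()> {
        if self.pushed == self.capacity {
            bail!("too many inputs (capacity {})", self.capacity);
        }
        let _ = tensor; // a real backend would convert and store the tensor here
        self.pushed += 1;
        Ok(())
    }
}

fn main() -> anyhow::Result<()> {
    let mut ctx = CountingCtx { pushed: 0, capacity: 1 };
    ctx.push_float32(ndarray::arr1(&[0.0_f32, 1.0]))?;
    assert!(ctx.push_float32(ndarray::arr1(&[2.0_f32])).is_err());
    Ok(())
}
```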
diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
index 0556b6d51..f8f376837 100644
--- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
+++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -1,18 +1,15 @@
use std::{fmt::Debug, vec};
-use anyhow::anyhow;
+use anyhow::{anyhow, bail, ensure};
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
-use once_cell::sync::Lazy;
-use onnxruntime::{
- environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
- TypeToTensorElementDataType,
+use ort::{
+ CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, ExecutionProvider as _,
+ GraphOptimizationLevel, IntoTensorElementType, TensorElementType, ValueType,
};
use crate::{devices::SupportedDevices, error::ErrorRepr};
-use self::assert_send::AssertSend;
-
use super::super::{
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor,
@@ -22,29 +19,28 @@ use super::super::{
pub(crate) enum Onnxruntime {}
impl InferenceRuntime for Onnxruntime {
- type Session = AssertSend<onnxruntime::session::Session<'static>>;
+ type Session = ort::Session;
type RunContext<'a> = OnnxruntimeRunContext<'a>;
fn supported_devices() -> crate::Result<SupportedDevices> {
- let mut cuda_support = false;
- let mut dml_support = false;
- for provider in onnxruntime::session::get_available_providers()
- .map_err(Into::into)
- .map_err(ErrorRepr::GetSupportedDevices)?
- .iter()
- {
- match provider.as_str() {
- "CUDAExecutionProvider" => cuda_support = true,
- "DmlExecutionProvider" => dml_support = true,
- _ => {}
- }
- }
+ // TODO: create `InferenceRuntime::init` and `InitInferenceRuntimeError`
+ build_ort_env_once().unwrap();
+
+ (|| {
+ let cpu = CPUExecutionProvider::default().is_available()?;
+ let cuda = CUDAExecutionProvider::default().is_available()?;
+ let dml = DirectMLExecutionProvider::default().is_available()?;
- Ok(SupportedDevices {
- cpu: true,
- cuda: cuda_support,
- dml: dml_support,
- })
+ ensure!(cpu, "missing `CPUExecutionProvider`");
+
+ Ok(SupportedDevices {
+ cpu: true,
+ cuda,
+ dml,
+ })
+ })()
+ .map_err(ErrorRepr::GetSupportedDevices)
+ .map_err(Into::into)
}
fn new_session(
@@ -55,48 +51,52 @@ impl InferenceRuntime for Onnxruntime {
Vec<ParamInfo<InputScalarKind>>,
Vec<ParamInfo<OutputScalarKind>>,
)> {
- let mut builder = ENVIRONMENT
- .new_session_builder()?
- .with_optimization_level(GraphOptimizationLevel::Basic)?
- .with_intra_op_num_threads(options.cpu_num_threads.into())?
- .with_inter_op_num_threads(options.cpu_num_threads.into())?;
-
- if options.use_gpu {
- #[cfg(feature = "directml")]
- {
- use onnxruntime::ExecutionMode;
-
- builder = builder
- .with_disable_mem_pattern()?
- .with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)?
- .with_append_execution_provider_directml(0)?;
- }
-
- #[cfg(not(feature = "directml"))]
- {
- builder = builder.with_append_execution_provider_cuda(Default::default())?;
- }
+ // TODO: create `InferenceRuntime::init` and `InitInferenceRuntimeError`
+ build_ort_env_once().unwrap();
+
+ let mut builder = ort::Session::builder()?
+ .with_optimization_level(GraphOptimizationLevel::Level1)?
+ .with_intra_threads(options.cpu_num_threads.into())?;
+
+ if options.use_gpu && cfg!(feature = "directml") {
+ builder = builder
+ .with_parallel_execution(false)?
+ .with_memory_pattern(false)?;
+ DirectMLExecutionProvider::default().register(&builder)?;
+ } else if options.use_gpu && cfg!(feature = "cuda") {
+ CUDAExecutionProvider::default().register(&builder)?;
}
let model = model()?;
- let sess = AssertSend::from(builder.with_model_from_memory(model)?);
+ let sess = builder.commit_from_memory(&{ model })?;
let input_param_infos = sess
.inputs
.iter()
.map(|info| {
- let dt = match info.input_type {
- TensorElementDataType::Float => Ok(InputScalarKind::Float32),
- TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
- TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
- TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
- TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
- TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
- TensorElementDataType::Int64 => Ok(InputScalarKind::Int64),
- TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
- TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
- TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
- TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
+ let ValueType::Tensor { ty, .. } = info.input_type else {
+ bail!(
+ "unexpected input value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` are supported",
+ info.name,
+ );
+ };
+
+ let dt = match ty {
+ TensorElementType::Float32 => Ok(InputScalarKind::Float32),
+ TensorElementType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
+ TensorElementType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
+ TensorElementType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
+ TensorElementType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
+ TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
+ TensorElementType::Int64 => Ok(InputScalarKind::Int64),
+ TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
+ TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
+ TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
+ TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
+ TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
+ TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
+ TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported input datatype `{actual}` for `{}`", info.name)
@@ -105,7 +105,7 @@ impl InferenceRuntime for Onnxruntime {
Ok(ParamInfo {
name: info.name.clone().into(),
dt,
- ndim: Some(info.dimensions.len()),
+ ndim: info.input_type.tensor_dimensions().map(Vec::len),
})
})
.collect::<anyhow::Result<_>>()?;
@@ -114,18 +114,29 @@ impl InferenceRuntime for Onnxruntime {
.outputs
.iter()
.map(|info| {
- let dt = match info.output_type {
- TensorElementDataType::Float => Ok(OutputScalarKind::Float32),
- TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
- TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
- TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
- TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
- TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
- TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"),
- TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
- TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
- TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
- TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
+ let ValueType::Tensor { ty, .. } = info.output_type else {
+ bail!(
+ "unexpected output value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` are supported",
+ info.name,
+ );
+ };
+
+ let dt = match ty {
+ TensorElementType::Float32 => Ok(OutputScalarKind::Float32),
+ TensorElementType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
+ TensorElementType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
+ TensorElementType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
+ TensorElementType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
+ TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
+ TensorElementType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"),
+ TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
+ TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
+ TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
+ TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
+ TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
+ TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
+ TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported output datatype `{actual}` for `{}`", info.name)
@@ -134,73 +145,69 @@ impl InferenceRuntime for Onnxruntime {
Ok(ParamInfo {
name: info.name.clone().into(),
dt,
- ndim: Some(info.dimensions.len()),
+ ndim: info.output_type.tensor_dimensions().map(|d| d.len()),
})
})
.collect::<anyhow::Result<_>>()?;
- return Ok((sess, input_param_infos, output_param_infos));
-
- static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
- Environment::builder()
- .with_name(env!("CARGO_PKG_NAME"))
- .with_log_level(LOGGING_LEVEL)
- .build()
- .unwrap()
- });
-
- const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) {
- LoggingLevel::Verbose
- } else {
- LoggingLevel::Warning
- };
+ Ok((sess, input_param_infos, output_param_infos))
}
fn run(
- OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
+ OnnxruntimeRunContext { sess, inputs }: OnnxruntimeRunContext<'_>,
) -> anyhow::Result<Vec<OutputTensor>> {
- // FIXME: Currently only `f32` is supported. The datatype can be obtained from the session at
- // runtime, so supporting other types is probably possible, but the migration to the ort crate
- // will likely happen before that becomes necessary, so this is left as is.
-
- if !sess
- .outputs
- .iter()
- .all(|info| matches!(info.output_type, TensorElementDataType::Float))
- {
- unimplemented!(
- "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output",
- );
- }
-
- let outputs = sess.run::<f32>(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;
-
- Ok(outputs
- .iter()
- .map(|o| OutputTensor::Float32((*o).clone().into_owned()))
- .collect())
+ let outputs = sess.run(&*inputs)?;
+
+ (0..outputs.len())
+ .map(|i| {
+ let output = &outputs[i];
+
+ let ValueType::Tensor { ty, .. } = output.dtype()? else {
+ bail!(
+ "unexpected output. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` are supported",
+ );
+ };
+
+ match ty {
+ TensorElementType::Float32 => {
+ let output = output.try_extract_tensor::<f32>()?;
+ Ok(OutputTensor::Float32(output.into_owned()))
+ }
+ _ => bail!("unexpected output tensor element data type"),
+ }
+ })
+ .collect()
}
}
+fn build_ort_env_once() -> ort::Result<()> {
+ static ONCE: once_cell::sync::OnceCell<()> = once_cell::sync::OnceCell::new();
+ ONCE.get_or_try_init(|| ort::init().with_name(env!("CARGO_PKG_NAME")).commit())?;
+ Ok(())
+}
+
pub(crate) struct OnnxruntimeRunContext<'sess> {
- sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>,
- inputs: Vec<Box<dyn onnxruntime::session::AnyArray>>,
+ sess: &'sess ort::Session,
+ inputs: Vec<ort::SessionInputValue<'static>>,
}
impl OnnxruntimeRunContext<'_> {
fn push_input(
&mut self,
- input: Array<impl TypeToTensorElementDataType + Debug + 'static, impl Dimension + 'static>,
- ) {
- self.inputs
- .push(Box::new(onnxruntime::session::NdArray::new(input)));
+ input: Array<
+ impl IntoTensorElementType + Debug + Clone + 'static,
+ impl Dimension + 'static,
+ >,
+ ) -> anyhow::Result<()> {
+ let input = ort::Value::from_array(input)?.into();
+ self.inputs.push(input);
+ Ok(())
}
}
-impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
- for OnnxruntimeRunContext<'sess>
-{
- fn from(sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>) -> Self {
+impl<'sess> From<&'sess mut ort::Session> for OnnxruntimeRunContext<'sess> {
+ fn from(sess: &'sess mut ort::Session) -> Self {
Self {
sess,
inputs: vec![],
@@ -214,40 +221,7 @@ impl PushInputTensor for OnnxruntimeRunContext<'_> {
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
- fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) {
- self.push_input(tensor);
- }
-}
-
-// FIXME: once the following has been properly verified, declare `Session` as `Send` on the onnxruntime-rs side.
-// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
-mod assert_send {
- use std::ops::{Deref, DerefMut};
-
- pub(crate) struct AssertSend<T>(T);
-
- impl From<onnxruntime::session::Session<'static>>
- for AssertSend<onnxruntime::session::Session<'static>>
- {
- fn from(session: onnxruntime::session::Session<'static>) -> Self {
- Self(session)
- }
+ fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) -> anyhow::Result<()> {
+ self.push_input(tensor)
}
-
- impl<T> Deref for AssertSend<T> {
- type Target = T;
-
- fn deref(&self) -> &Self::Target {
- &self.0
- }
- }
-
- impl<T> DerefMut for AssertSend<T> {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.0
- }
- }
-
- // SAFETY: `Session` is probably "send"able.
- unsafe impl<T> Send for AssertSend<T> {}
}
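
Note: `build_ort_env_once` above relies on `once_cell::sync::OnceCell::get_or_try_init`: the initializer runs at most once, concurrent callers block until it finishes, and on `Err` the cell stays empty so a later call may retry. A standalone sketch of the pattern, with `init_backend` standing in for `ort::init().with_name(...).commit()`:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

use once_cell::sync::OnceCell;

static CALLS: AtomicUsize = AtomicUsize::new(0);

// Stand-in for `ort::init().with_name(env!("CARGO_PKG_NAME")).commit()`.
fn init_backend() -> anyhow::Result<()> {
    CALLS.fetch_add(1, Ordering::SeqCst);
    Ok(())
}

fn build_env_once() -> anyhow::Result<()> {
    static ONCE: OnceCell<()> = OnceCell::new();
    // Runs the initializer only if the cell is still empty; an `Err` leaves it
    // empty, so initialization can be attempted again later.
    ONCE.get_or_try_init(init_backend)?;
    Ok(())
}

fn main() -> anyhow::Result<()> {
    build_env_once()?;
    build_env_once()?; // no-op: already initialized
    assert_eq!(CALLS.load(Ordering::SeqCst), 1);
    Ok(())
}
```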
diff --git a/crates/voicevox_core/src/infer/session_set.rs b/crates/voicevox_core/src/infer/session_set.rs
index 56d570f98..cdd179680 100644
--- a/crates/voicevox_core/src/infer/session_set.rs
+++ b/crates/voicevox_core/src/infer/session_set.rs
@@ -94,9 +94,8 @@ impl<R: InferenceRuntime, I: InferenceInputSignature> InferenceSessionCell<R, I>
input: I,
) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
let inner = &mut self.inner.lock().unwrap();
- let ctx = input.make_run_context::<R>(inner);
- R::run(ctx)
- .and_then(TryInto::try_into)
- .map_err(|e| ErrorRepr::InferenceFailed(e).into())
+ (|| R::run(input.make_run_context::<R>(inner)?)?.try_into())()
+ .map_err(ErrorRepr::InferenceFailed)
+ .map_err(Into::into)
}
}
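
Note: the `(|| ...)()` form above is an immediately invoked closure. It gives `?` a scope in which the fallible `make_run_context`, `R::run`, and the `try_into` conversion all collapse into one `anyhow::Error` before the single `ErrorRepr::InferenceFailed` wrapping. A self-contained sketch of the idiom with illustrative names:

```rust
use anyhow::anyhow;

fn make_ctx(valid: bool) -> anyhow::Result<u32> {
    if valid {
        Ok(1)
    } else {
        Err(anyhow!("context construction failed"))
    }
}

fn run(ctx: u32) -> anyhow::Result<u32> {
    Ok(ctx * 10)
}

// `String` stands in for the crate's real error type (`ErrorRepr`).
fn infer(valid: bool) -> Result<u32, String> {
    // Both fallible steps funnel through `?` inside the closure, and the
    // combined error is mapped into the caller's error type exactly once.
    (|| -> anyhow::Result<u32> { run(make_ctx(valid)?) })()
        .map_err(|e| format!("inference failed: {e}"))
}

fn main() {
    assert_eq!(infer(true), Ok(10));
    assert!(infer(false).is_err());
}
```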
diff --git a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml
index a6314f105..ad3a65fa7 100644
--- a/crates/voicevox_core_c_api/Cargo.toml
+++ b/crates/voicevox_core_c_api/Cargo.toml
@@ -13,6 +13,7 @@ harness = false
name = "e2e"
[features]
+cuda = ["voicevox_core/cuda"]
directml = ["voicevox_core/directml"]
[dependencies]
diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs
index a5da9b6d3..fbb0bf6bf 100644
--- a/crates/voicevox_core_c_api/src/lib.rs
+++ b/crates/voicevox_core_c_api/src/lib.rs
@@ -59,7 +59,13 @@ fn init_logger_once() {
.with_env_filter(if env::var_os(EnvFilter::DEFAULT_ENV).is_some() {
EnvFilter::from_default_env()
} else {
- "error,voicevox_core=info,voicevox_core_c_api=info,onnxruntime=info".into()
+ pub const ORT_LOGGING_LEVEL: &str = if cfg!(debug_assertions) {
+ "info"
+ } else {
+ "warn"
+ };
+ format!("error,voicevox_core=info,voicevox_core_c_api=info,ort={ORT_LOGGING_LEVEL}")
+ .into()
})
.with_timer(local_time as fn(&mut Writer<'_>) -> _)
.with_ansi(ansi)
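
Note: the directive string assembled above parses into a `tracing_subscriber::EnvFilter`: everything defaults to `error`, the named crates log at `info`, and `ort` at a build-dependent level. A minimal sketch, assuming the `tracing` crate and `tracing-subscriber` with its `env-filter` feature:

```rust
use tracing_subscriber::EnvFilter;

fn main() {
    // Mirrors the default directive from the diff: `error` globally, `info`
    // for the named crates, and a debug/release-dependent level for ort.
    let ort_level = if cfg!(debug_assertions) { "info" } else { "warn" };
    let filter = EnvFilter::new(format!(
        "error,voicevox_core=info,voicevox_core_c_api=info,ort={ort_level}"
    ));
    tracing_subscriber::fmt().with_env_filter(filter).init();

    tracing::error!("shown: matches the global `error` level");
    tracing::info!("hidden: this crate is not listed in the directive");
}
```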
diff --git a/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs b/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs
index 1e4958eda..cfbec5c31 100644
--- a/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs
+++ b/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs
@@ -46,11 +46,7 @@ pub(crate) fn exec() -> anyhow::Result<()> {
// We would like to skip `cargo build` when there is nothing to test, but the logic that decides
// this is private. So we skip only when the CLI options include neither `--ignored` nor
// `--include-ignored`
if args.ignored || args.include_ignored {
- let mut cmd = cmd!(env!("CARGO"), "build", "--release", "--lib");
- for (k, v) in C::BUILD_ENVS {
- cmd = cmd.env(k, v);
- }
- cmd.run()?;
+ cmd!(env!("CARGO"), "build", "--release", "--lib").run()?;
ensure!(
C::cdylib_path().exists(),
@@ -102,7 +98,6 @@ pub(crate) fn exec() -> anyhow::Result<()> {
pub(crate) trait TestContext {
const TARGET_DIR: &'static str;
const CDYLIB_NAME: &'static str;
- const BUILD_ENVS: &'static [(&'static str, &'static str)];
const RUNTIME_ENVS: &'static [(&'static str, &'static str)];
}
diff --git a/crates/voicevox_core_c_api/tests/e2e/main.rs b/crates/voicevox_core_c_api/tests/e2e/main.rs
index 91f5e06e9..43dc3a95e 100644
--- a/crates/voicevox_core_c_api/tests/e2e/main.rs
+++ b/crates/voicevox_core_c_api/tests/e2e/main.rs
@@ -24,16 +24,6 @@ fn main() -> anyhow::Result<()> {
impl assert_cdylib::TestContext for TestContext {
const TARGET_DIR: &'static str = "../../target";
const CDYLIB_NAME: &'static str = "voicevox_core";
- const BUILD_ENVS: &'static [(&'static str, &'static str)] = &[
- // If onnxruntime-sys is built for the first time while other unit tests are running, the
- // `$ORT_OUT_DIR` hack causes problems on Windows, so disable the hack itself.
- //
- // If we could `cargo build` without introducing a diff in features, onnxruntime-sys would
- // not be rebuilt at all, but it is next to impossible for this binary to know the state of
- // `cargo build`.
- ("ORT_OUT_DIR", ""),
- // Disable DirectML and CUDA
- ("ORT_USE_CUDA", "0"),
- ];
const RUNTIME_ENVS: &'static [(&'static str, &'static str)] =
&[("VV_MODELS_ROOT_DIR", VV_MODELS_ROOT_DIR)];
}
diff --git a/crates/voicevox_core_java_api/Cargo.toml b/crates/voicevox_core_java_api/Cargo.toml
index 887813685..06b2af618 100644
--- a/crates/voicevox_core_java_api/Cargo.toml
+++ b/crates/voicevox_core_java_api/Cargo.toml
@@ -8,6 +8,7 @@ publish.workspace = true
crate-type = ["cdylib"]
[features]
+cuda = ["voicevox_core/cuda"]
directml = ["voicevox_core/directml"]
[dependencies]
diff --git a/crates/voicevox_core_java_api/settings.gradle b/crates/voicevox_core_java_api/settings.gradle
index 20a5e2c6a..75f5810ac 100644
--- a/crates/voicevox_core_java_api/settings.gradle
+++ b/crates/voicevox_core_java_api/settings.gradle
@@ -40,5 +40,5 @@ gradle.ext {
gsonVersion = '2.10.1'
jakartaValidationVersion = '3.0.2'
jakartaAnnotationVersion = '2.1.1'
- onnxruntimeVersion = '1.14.0'
+ onnxruntimeVersion = '1.17.3'
}
diff --git a/crates/voicevox_core_java_api/src/logger.rs b/crates/voicevox_core_java_api/src/logger.rs
index 4800452ca..30545725e 100644
--- a/crates/voicevox_core_java_api/src/logger.rs
+++ b/crates/voicevox_core_java_api/src/logger.rs
@@ -10,10 +10,11 @@ extern "system" fn Java_jp_hiroshiba_voicevoxcore_Dll_00024LoggerInitializer_ini
android_logger::Config::default()
.with_tag("VoicevoxCore")
.with_filter(
- android_logger::FilterBuilder::new()
- .parse("error,voicevox_core=info,voicevox_core_java_api=info,onnxruntime=error")
- .build(),
- ),
+ android_logger::FilterBuilder::new()
+ // FIXME: ort should also emit `warn`
+ .parse("error,voicevox_core=info,voicevox_core_java_api=info,ort=error")
+ .build(),
+ ),
);
} else {
// TODO: Make log output nicer on non-Android platforms. (Use System.Logger?)
@@ -29,7 +30,8 @@ extern "system" fn Java_jp_hiroshiba_voicevoxcore_Dll_00024LoggerInitializer_ini
.with_env_filter(if env::var_os(EnvFilter::DEFAULT_ENV).is_some() {
EnvFilter::from_default_env()
} else {
- "error,voicevox_core=info,voicevox_core_c_api=info,onnxruntime=error".into()
+ // FIXME: this is not `c_api`, and ort should also emit `warn`
+ "error,voicevox_core=info,voicevox_core_c_api=info,ort=error".into()
})
.with_timer(local_time as fn(&mut Writer<'_>) -> _)
.with_ansi(out().is_terminal() && env_allows_ansi())
diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs
index 72bc4d18a..d24a20ab1 100644
--- a/crates/voicevox_core_macros/src/inference_domain.rs
+++ b/crates/voicevox_core_macros/src/inference_domain.rs
@@ -223,22 +223,28 @@ pub(crate) fn derive_inference_input_signature(
fn make_run_context<R: crate::infer::InferenceRuntime>(
self,
sess: &mut R::Session,
- ) -> R::RunContext<'_> {
+ ) -> ::anyhow::Result<R::RunContext<'_>> {
let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess);
#(
- __ArrayExt::push_to_ctx(self.#field_names, &mut ctx);
+ __ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?;
)*
- return ctx;
+ return ::std::result::Result::Ok(ctx);
trait __ArrayExt {
- fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor);
+ fn push_to_ctx(
+ self,
+ ctx: &mut impl crate::infer::PushInputTensor,
+ ) -> ::anyhow::Result<()>;
}
impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
- fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) {
- A::push_tensor_to_ctx(self, ctx);
+ fn push_to_ctx(
+ self,
+ ctx: &mut impl crate::infer::PushInputTensor,
+ ) -> ::anyhow::Result<()> {
+ A::push_tensor_to_ctx(self, ctx)
}
}
}
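
Note: in effect, for an input struct with tensor fields the derive now emits a `make_run_context` whose pushes are `?`-propagated instead of infallible. A simplified, self-contained sketch of that shape; `RunContext` here is a stand-in, not the crate's real context type:

```rust
use anyhow::ensure;
use ndarray::Array1;

// Stand-in for the runtime's run context, just to show the fallible shape.
#[derive(Default)]
struct RunContext {
    inputs: usize,
}

impl RunContext {
    fn push(&mut self, tensor: &Array1<i64>) -> anyhow::Result<()> {
        ensure!(!tensor.is_empty(), "refusing to push an empty tensor");
        self.inputs += 1;
        Ok(())
    }
}

struct Input {
    phoneme_vector: Array1<i64>,
    speaker_id: Array1<i64>,
}

impl Input {
    // The shape the derive now generates: each field is pushed with `?`, and
    // the finished context comes back inside a `Result`.
    fn make_run_context(self, mut ctx: RunContext) -> anyhow::Result<RunContext> {
        ctx.push(&self.phoneme_vector)?;
        ctx.push(&self.speaker_id)?;
        Ok(ctx)
    }
}

fn main() -> anyhow::Result<()> {
    let input = Input {
        phoneme_vector: ndarray::arr1(&[1, 2, 3]),
        speaker_id: ndarray::arr1(&[0]),
    };
    let ctx = input.make_run_context(RunContext::default())?;
    assert_eq!(ctx.inputs, 2);
    Ok(())
}
```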
diff --git a/crates/voicevox_core_python_api/Cargo.toml b/crates/voicevox_core_python_api/Cargo.toml
index be3ecbf27..5ccd1dc41 100644
--- a/crates/voicevox_core_python_api/Cargo.toml
+++ b/crates/voicevox_core_python_api/Cargo.toml
@@ -8,6 +8,7 @@ publish.workspace = true
crate-type = ["cdylib"]
[features]
+cuda = ["voicevox_core/cuda"]
directml = ["voicevox_core/directml"]
[dependencies]