From 5a644cad50457d22cdde183020abdfb5d00fd800 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Wed, 22 May 2024 09:34:59 +0900 Subject: [PATCH] =?UTF-8?q?onnxruntime-rs=E3=81=8B=E3=82=89ort=E3=81=AB?= =?UTF-8?q?=E4=B9=97=E3=82=8A=E6=8F=9B=E3=81=88=E3=82=8B=20(#725)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * onnxruntime-rsからortに乗り換える * `--features onnxruntime/disable-sys-build-script`を消す * ortをアップデート * ortをアップデート * `onnxruntimeVersion`をアップデート * libonnxruntimeのコピー処理を更新 * ortをアップデート * libonnxruntimeのコピー処理を更新 * ortをアップデート * ortをアップデート * `ort::ExecutionProvider::is_available`を使う * `todo!`を消す * ortをアップデート * ortにあったAPIを使う * ortをアップデート * `$ORT_OUT_DIR`を削除 * ortをアップデート * ログのフィルタを更新 * ortをアップデート * tracingのレベルでortのログを抑える * Minor refactor * ortをアップデート * Fix Cargo.lock * Gradleのlibonnxruntimeのバージョンを更新 * ort v2.0.0-rc.1ベースに切り替える * Gradleのlibonnxruntimeのバージョンを更新 * `with_execution_provider` → `register` * ort v2.0.0-rc.2ベースに切り替える * Gradleのlibonnxruntimeのバージョンを更新 * voicevox-ortを更新 * VOICEVOX/ort#2 に追従する --- .cargo/config.toml | 4 - .github/workflows/build_and_deploy.yml | 23 +- .github/workflows/test.yml | 14 +- Cargo.lock | 123 +++++--- Cargo.toml | 6 +- crates/voicevox_core/Cargo.toml | 5 +- crates/voicevox_core/src/infer.rs | 16 +- .../src/infer/runtimes/onnxruntime.rs | 284 ++++++++---------- crates/voicevox_core/src/infer/session_set.rs | 7 +- crates/voicevox_core_c_api/Cargo.toml | 1 + crates/voicevox_core_c_api/src/lib.rs | 8 +- .../tests/e2e/assert_cdylib.rs | 7 +- crates/voicevox_core_c_api/tests/e2e/main.rs | 10 - crates/voicevox_core_java_api/Cargo.toml | 1 + crates/voicevox_core_java_api/settings.gradle | 2 +- crates/voicevox_core_java_api/src/logger.rs | 12 +- .../src/inference_domain.rs | 18 +- crates/voicevox_core_python_api/Cargo.toml | 1 + 18 files changed, 269 insertions(+), 273 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 64fcc6aea..f015cd17a 100644 --- a/.cargo/config.toml +++ 
b/.cargo/config.toml @@ -4,10 +4,6 @@ xtask = "run -p xtask --" [env] CARGO_WORKSPACE_DIR = { value = "", relative = true } -# Windows環境でテストエラーになるのを防ぐために設定するworkaround -# https://github.com/VOICEVOX/onnxruntime-rs/issues/3#issuecomment-1207381367 -ORT_OUT_DIR = { value = "target/debug/deps", relative = true } - [target.aarch64-unknown-linux-gnu] linker = "aarch64-linux-gnu-gcc" diff --git a/.github/workflows/build_and_deploy.yml b/.github/workflows/build_and_deploy.yml index 83d895c77..01ddfc434 100644 --- a/.github/workflows/build_and_deploy.yml +++ b/.github/workflows/build_and_deploy.yml @@ -58,7 +58,6 @@ jobs: "target": "x86_64-pc-windows-msvc", "artifact_name": "windows-x64-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -67,16 +66,14 @@ jobs: "target": "x86_64-pc-windows-msvc", "artifact_name": "windows-x64-directml", "whl_local_version": "directml", - "use_cuda": false, "can_skip_in_simple_test": false }, { "os": "windows-2019", - "features": "", + "features": "cuda", "target": "x86_64-pc-windows-msvc", "artifact_name": "windows-x64-cuda", "whl_local_version": "cuda", - "use_cuda": true, "can_skip_in_simple_test": true }, { @@ -85,7 +82,6 @@ jobs: "target": "i686-pc-windows-msvc", "artifact_name": "windows-x86-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -94,16 +90,14 @@ jobs: "target": "x86_64-unknown-linux-gnu", "artifact_name": "linux-x64-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { "os": "ubuntu-20.04", - "features": "", + "features": "cuda", "target": "x86_64-unknown-linux-gnu", "artifact_name": "linux-x64-gpu", "whl_local_version": "cuda", - "use_cuda": true, "can_skip_in_simple_test": false }, { @@ -112,7 +106,6 @@ jobs: "target": "aarch64-unknown-linux-gnu", "artifact_name": "linux-arm64-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -120,7 +113,6 @@ jobs: 
"features": "", "target": "aarch64-linux-android", "artifact_name": "android-arm64-cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -128,7 +120,6 @@ jobs: "features": "", "target": "x86_64-linux-android", "artifact_name": "android-x86_64-cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -137,7 +128,6 @@ jobs: "target": "aarch64-apple-darwin", "artifact_name": "osx-arm64-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": false }, { @@ -146,7 +136,6 @@ jobs: "target": "x86_64-apple-darwin", "artifact_name": "osx-x64-cpu", "whl_local_version": "cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -154,7 +143,6 @@ jobs: "features": "", "target": "aarch64-apple-ios", "artifact_name": "ios-arm64-cpu", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -162,7 +150,6 @@ jobs: "features": "", "target": "aarch64-apple-ios-sim", "artifact_name": "ios-arm64-cpu-sim", - "use_cuda": false, "can_skip_in_simple_test": true }, { @@ -170,7 +157,6 @@ jobs: "features": "", "target": "x86_64-apple-ios", "artifact_name": "ios-x64-cpu", - "use_cuda": false, "can_skip_in_simple_test": true } ]' @@ -268,7 +254,6 @@ jobs: fi env: RUSTFLAGS: -C panic=abort - ORT_USE_CUDA: ${{ matrix.use_cuda }} - name: build voicevox_core_python_api if: matrix.whl_local_version id: build-voicevox-core-python-api @@ -286,8 +271,6 @@ jobs: build > /dev/null 2>&1 fi echo "whl=$(find ./target/wheels -type f)" >> "$GITHUB_OUTPUT" - env: - ORT_USE_CUDA: ${{ matrix.use_cuda }} - name: build voicevox_core_java_api if: contains(matrix.target, 'android') run: | @@ -305,7 +288,7 @@ jobs: cp -v crates/voicevox_core_c_api/include/voicevox_core.h "artifact/${{ env.ASSET_NAME }}" cp -v target/${{ matrix.target }}/release/*voicevox_core.{dll,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true cp -v target/${{ matrix.target }}/release/voicevox_core.dll.lib "artifact/${{ env.ASSET_NAME }}/voicevox_core.lib" || true - cp -v -n target/${{ 
matrix.target }}/release/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/*.{dll,so.*,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true + cp -v -n target/${{ matrix.target }}/release/{,lib}onnxruntime*.{dll,so.*,so,dylib} "artifact/${{ env.ASSET_NAME }}" || true # libonnxruntimeについてはバージョン付のshared libraryを使用するためバージョンがついてないものを削除する rm -f artifact/${{ env.ASSET_NAME }}/libonnxruntime.{so,dylib} cp -v README.md "artifact/${{ env.ASSET_NAME }}/README.txt" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0496eebdb..d0639d045 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -72,8 +72,8 @@ jobs: with: python-version: "3.8" - uses: Swatinem/rust-cache@v2 - - run: cargo clippy -vv --all-features --features onnxruntime/disable-sys-build-script --tests -- -D clippy::all -D warnings --no-deps - - run: cargo clippy -vv --all-features --features onnxruntime/disable-sys-build-script -- -D clippy::all -D warnings --no-deps + - run: cargo clippy -vv --all-features --tests -- -D clippy::all -D warnings --no-deps + - run: cargo clippy -vv --all-features -- -D clippy::all -D warnings --no-deps - run: cargo fmt -- --check rust-unit-test: @@ -199,8 +199,8 @@ jobs: mkdir -p example/cpp/unix/voicevox_core/ cp -v crates/voicevox_core_c_api/include/voicevox_core.h example/cpp/unix/voicevox_core/ cp -v target/debug/libvoicevox_core.{so,dylib} example/cpp/unix/voicevox_core/ || true - cp -v target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.so.* example/cpp/unix/voicevox_core/ || true - cp -v target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.*.dylib example/cpp/unix/voicevox_core/ || true + cp -v target/debug/libonnxruntime.so.* example/cpp/unix/voicevox_core/ || true + cp -v target/debug/libonnxruntime.*.dylib example/cpp/unix/voicevox_core/ || true - if: startsWith(matrix.os, 'mac') uses: jwlawson/actions-setup-cmake@v1.13 @@ -281,9 +281,9 @@ jobs: - run: 
poetry run maturin develop --locked - name: 必要なDLLをコピーしてpytestを実行 run: | - cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/onnxruntime.dll . || true - cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.so.* . || true - cp -v ../../target/debug/build/onnxruntime-sys-*/out/onnxruntime_*/onnxruntime-*/lib/libonnxruntime.*.dylib . || true + cp -v ../../target/debug/onnxruntime.dll . || true + cp -v ../../target/debug/libonnxruntime.so.* . || true + cp -v ../../target/debug/libonnxruntime.*.dylib . || true poetry run pytest - name: Exampleを実行 diff --git a/Cargo.lock b/Cargo.lock index 4015b7e36..0f9a78469 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -426,12 +426,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "chunked_transfer" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" - [[package]] name = "cipher" version = "0.3.0" @@ -684,6 +678,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1278,6 +1278,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1410,7 +1420,7 @@ checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" dependencies = [ "http", "hyper", - "rustls", + "rustls 0.20.6", "tokio", "tokio-rustls", ] @@ -1993,30 +2003,6 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" -[[package]] -name = "onnxruntime" -version = "0.1.0" -source = "git+https://github.com/VOICEVOX/onnxruntime-rs.git?rev=ebb9dcb9b26ee681889b52b6db3b4f642b04a250#ebb9dcb9b26ee681889b52b6db3b4f642b04a250" -dependencies = [ - "lazy_static", - "ndarray", - "onnxruntime-sys", - "thiserror", - "tracing", -] - -[[package]] -name = "onnxruntime-sys" -version = "0.0.25" -source = "git+https://github.com/VOICEVOX/onnxruntime-rs.git?rev=ebb9dcb9b26ee681889b52b6db3b4f642b04a250#ebb9dcb9b26ee681889b52b6db3b4f642b04a250" -dependencies = [ - "flate2", - "once_cell", - "tar", - "ureq", - "zip", -] - [[package]] name = "opaque-debug" version = "0.3.0" @@ -2598,7 +2584,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.20.6", "rustls-pemfile", "serde", "serde_json", @@ -2611,7 +2597,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 0.22.5", "winreg", ] @@ -2728,6 +2714,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -2737,6 +2735,16 @@ dependencies = [ "base64 0.21.0", ] +[[package]] +name = "rustls-webpki" +version = "0.101.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.11" @@ -3262,7 +3270,7 @@ version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ - "rustls", + "rustls 0.20.6", "tokio", 
"webpki", ] @@ -3480,19 +3488,17 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "ureq" -version = "2.5.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97acb4c28a254fd7a4aeec976c46a7fa404eac4d7c134b30c75144846d7cb8f" +checksum = "f5ccd538d4a604753ebc2f17cd9946e89b77bf87f6a8e2309667c6f2e87855e3" dependencies = [ - "base64 0.13.0", - "chunked_transfer", - "flate2", + "base64 0.21.0", "log", "once_cell", - "rustls", + "rustls 0.21.7", + "rustls-webpki", "url", - "webpki", - "webpki-roots", + "webpki-roots 0.25.4", ] [[package]] @@ -3535,6 +3541,31 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "voicevox-ort" +version = "2.0.0-rc.2" +source = "git+https://github.com/VOICEVOX/ort.git?rev=a2d6ae22327869e896bf4c16828734d09516d2d9#a2d6ae22327869e896bf4c16828734d09516d2d9" +dependencies = [ + "half", + "js-sys", + "ndarray", + "thiserror", + "tracing", + "voicevox-ort-sys", + "web-sys", +] + +[[package]] +name = "voicevox-ort-sys" +version = "2.0.0-rc.2" +source = "git+https://github.com/VOICEVOX/ort.git?rev=a2d6ae22327869e896bf4c16828734d09516d2d9#a2d6ae22327869e896bf4c16828734d09516d2d9" +dependencies = [ + "flate2", + "sha2", + "tar", + "ureq", +] + [[package]] name = "voicevox_core" version = "0.0.0" @@ -3559,7 +3590,6 @@ dependencies = [ "nanoid", "ndarray", "once_cell", - "onnxruntime", "open_jtalk", "ouroboros", "pretty_assertions", @@ -3578,6 +3608,7 @@ dependencies = [ "tokio", "tracing", "uuid", + "voicevox-ort", "voicevox_core_macros", "windows", "zip", @@ -3797,6 +3828,12 @@ dependencies = [ "webpki", ] +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + [[package]] name = 
"which" version = "4.3.0" diff --git a/Cargo.toml b/Cargo.toml index 69d924a16..7102127d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,9 +86,9 @@ voicevox_core = { path = "crates/voicevox_core" } windows = "0.43.0" zip = "0.6.3" -[workspace.dependencies.onnxruntime] -git = "https://github.com/VOICEVOX/onnxruntime-rs.git" -rev = "ebb9dcb9b26ee681889b52b6db3b4f642b04a250" +[workspace.dependencies.voicevox-ort] +git = "https://github.com/VOICEVOX/ort.git" +rev = "a2d6ae22327869e896bf4c16828734d09516d2d9" [workspace.dependencies.open_jtalk] git = "https://github.com/VOICEVOX/open_jtalk-rs.git" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 5c8dd440d..731862356 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -6,7 +6,8 @@ publish.workspace = true [features] default = [] -directml = ["onnxruntime/directml"] +cuda = ["voicevox-ort/cuda"] +directml = ["voicevox-ort/directml"] [dependencies] anyhow.workspace = true @@ -27,7 +28,6 @@ jlabel.workspace = true nanoid.workspace = true ndarray.workspace = true once_cell.workspace = true -onnxruntime.workspace = true open_jtalk.workspace = true ouroboros.workspace = true rayon.workspace = true @@ -43,6 +43,7 @@ tokio = { workspace = true, features = ["rt"] } # FIXME: feature-gateする tracing.workspace = true uuid = { workspace = true, features = ["v4", "serde"] } voicevox_core_macros = { path = "../voicevox_core_macros" } +voicevox-ort = { workspace = true, features = ["ndarray", "download-binaries"] } zip.workspace = true [dev-dependencies] diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index fc8954e7d..c2cad1d7d 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -79,16 +79,20 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; const PARAM_INFOS: &'static [ParamInfo]; - fn 
make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; + fn make_run_context( + self, + sess: &mut R::Session, + ) -> anyhow::Result>; } pub(crate) trait InputScalar: Sized { const KIND: InputScalarKind; + // TODO: `Array`ではなく`ArrayView`を取ることができるかもしれない fn push_tensor_to_ctx( tensor: Array, visitor: &mut impl PushInputTensor, - ); + ) -> anyhow::Result<()>; } #[duplicate_item( @@ -102,8 +106,8 @@ impl InputScalar for T { fn push_tensor_to_ctx( tensor: Array, ctx: &mut impl PushInputTensor, - ) { - ctx.push(tensor); + ) -> anyhow::Result<()> { + ctx.push(tensor) } } @@ -117,8 +121,8 @@ pub(crate) enum InputScalarKind { } pub(crate) trait PushInputTensor { - fn push_int64(&mut self, tensor: Array); - fn push_float32(&mut self, tensor: Array); + fn push_int64(&mut self, tensor: Array) -> anyhow::Result<()>; + fn push_float32(&mut self, tensor: Array) -> anyhow::Result<()>; } /// 推論操作の出力シグネチャ。 diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 0556b6d51..f8f376837 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,18 +1,15 @@ use std::{fmt::Debug, vec}; -use anyhow::anyhow; +use anyhow::{anyhow, bail, ensure}; use duplicate::duplicate_item; use ndarray::{Array, Dimension}; -use once_cell::sync::Lazy; -use onnxruntime::{ - environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, - TypeToTensorElementDataType, +use ort::{ + CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, ExecutionProvider as _, + GraphOptimizationLevel, IntoTensorElementType, TensorElementType, ValueType, }; use crate::{devices::SupportedDevices, error::ErrorRepr}; -use self::assert_send::AssertSend; - use super::super::{ DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalarKind, OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor, @@ -22,29 +19,28 @@ 
use super::super::{ pub(crate) enum Onnxruntime {} impl InferenceRuntime for Onnxruntime { - type Session = AssertSend>; + type Session = ort::Session; type RunContext<'a> = OnnxruntimeRunContext<'a>; fn supported_devices() -> crate::Result { - let mut cuda_support = false; - let mut dml_support = false; - for provider in onnxruntime::session::get_available_providers() - .map_err(Into::into) - .map_err(ErrorRepr::GetSupportedDevices)? - .iter() - { - match provider.as_str() { - "CUDAExecutionProvider" => cuda_support = true, - "DmlExecutionProvider" => dml_support = true, - _ => {} - } - } + // TODO: `InferenceRuntime::init`と`InitInferenceRuntimeError`を作る + build_ort_env_once().unwrap(); + + (|| { + let cpu = CPUExecutionProvider::default().is_available()?; + let cuda = CUDAExecutionProvider::default().is_available()?; + let dml = DirectMLExecutionProvider::default().is_available()?; - Ok(SupportedDevices { - cpu: true, - cuda: cuda_support, - dml: dml_support, - }) + ensure!(cpu, "missing `CPUExecutionProvider`"); + + Ok(SupportedDevices { + cpu: true, + cuda, + dml, + }) + })() + .map_err(ErrorRepr::GetSupportedDevices) + .map_err(Into::into) } fn new_session( @@ -55,48 +51,52 @@ impl InferenceRuntime for Onnxruntime { Vec>, Vec>, )> { - let mut builder = ENVIRONMENT - .new_session_builder()? - .with_optimization_level(GraphOptimizationLevel::Basic)? - .with_intra_op_num_threads(options.cpu_num_threads.into())? - .with_inter_op_num_threads(options.cpu_num_threads.into())?; - - if options.use_gpu { - #[cfg(feature = "directml")] - { - use onnxruntime::ExecutionMode; - - builder = builder - .with_disable_mem_pattern()? - .with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)? 
- .with_append_execution_provider_directml(0)?; - } - - #[cfg(not(feature = "directml"))] - { - builder = builder.with_append_execution_provider_cuda(Default::default())?; - } + // TODO: `InferenceRuntime::init`と`InitInferenceRuntimeError`を作る + build_ort_env_once().unwrap(); + + let mut builder = ort::Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level1)? + .with_intra_threads(options.cpu_num_threads.into())?; + + if options.use_gpu && cfg!(feature = "directml") { + builder = builder + .with_parallel_execution(false)? + .with_memory_pattern(false)?; + DirectMLExecutionProvider::default().register(&builder)?; + } else if options.use_gpu && cfg!(feature = "cuda") { + CUDAExecutionProvider::default().register(&builder)?; } let model = model()?; - let sess = AssertSend::from(builder.with_model_from_memory(model)?); + let sess = builder.commit_from_memory(&{ model })?; let input_param_infos = sess .inputs .iter() .map(|info| { - let dt = match info.input_type { - TensorElementDataType::Float => Ok(InputScalarKind::Float32), - TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), - TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), - TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), - TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), - TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), - TensorElementDataType::Int64 => Ok(InputScalarKind::Int64), - TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), - TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), - TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), - TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + let ValueType::Tensor { ty, .. } = info.input_type else { + bail!( + "unexpected input value type for `{}`. 
currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + info.name, + ); + }; + + let dt = match ty { + TensorElementType::Float32 => Ok(InputScalarKind::Float32), + TensorElementType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementType::Int64 => Ok(InputScalarKind::Int64), + TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"), + TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"), + TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"), } .map_err(|actual| { anyhow!("unsupported input datatype `{actual}` for `{}`", info.name) @@ -105,7 +105,7 @@ impl InferenceRuntime for Onnxruntime { Ok(ParamInfo { name: info.name.clone().into(), dt, - ndim: Some(info.dimensions.len()), + ndim: info.input_type.tensor_dimensions().map(Vec::len), }) }) .collect::>()?; @@ -114,18 +114,29 @@ impl InferenceRuntime for Onnxruntime { .outputs .iter() .map(|info| { - let dt = match info.output_type { - TensorElementDataType::Float => Ok(OutputScalarKind::Float32), - TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), - TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), - TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), - TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), - 
TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), - TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), - TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), - TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), - TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), - TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + let ValueType::Tensor { ty, .. } = info.output_type else { + bail!( + "unexpected output value type for `{}`. currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + info.name, + ); + }; + + let dt = match ty { + TensorElementType::Float32 => Ok(OutputScalarKind::Float32), + TensorElementType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), + TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"), + TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"), + TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"), } .map_err(|actual| { anyhow!("unsupported output datatype `{actual}` for `{}`", info.name) @@ -134,73 +145,69 @@ impl InferenceRuntime for Onnxruntime { Ok(ParamInfo { name: info.name.clone().into(), dt, - ndim: Some(info.dimensions.len()), + 
ndim: info.output_type.tensor_dimensions().map(|d| d.len()), }) }) .collect::>()?; - return Ok((sess, input_param_infos, output_param_infos)); - - static ENVIRONMENT: Lazy = Lazy::new(|| { - Environment::builder() - .with_name(env!("CARGO_PKG_NAME")) - .with_log_level(LOGGING_LEVEL) - .build() - .unwrap() - }); - - const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) { - LoggingLevel::Verbose - } else { - LoggingLevel::Warning - }; + Ok((sess, input_param_infos, output_param_infos)) } fn run( - OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, + OnnxruntimeRunContext { sess, inputs }: OnnxruntimeRunContext<'_>, ) -> anyhow::Result> { - // FIXME: 現状では`f32`のみ対応。実行時にsessionからdatatypeが取れるので、別の型の対応も - // おそらく可能ではあるが、それが必要になるよりもortクレートへの引越しが先になると思われる - // のでこのままにする。 - - if !sess - .outputs - .iter() - .all(|info| matches!(info.output_type, TensorElementDataType::Float)) - { - unimplemented!( - "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output", - ); - } - - let outputs = sess.run::(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; - - Ok(outputs - .iter() - .map(|o| OutputTensor::Float32((*o).clone().into_owned())) - .collect()) + let outputs = sess.run(&*inputs)?; + + (0..outputs.len()) + .map(|i| { + let output = &outputs[i]; + + let ValueType::Tensor { ty, .. } = output.dtype()? else { + bail!( + "unexpected output. 
currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + ); + }; + + match ty { + TensorElementType::Float32 => { + let output = output.try_extract_tensor::()?; + Ok(OutputTensor::Float32(output.into_owned())) + } + _ => bail!("unexpected output tensor element data type"), + } + }) + .collect() } } +fn build_ort_env_once() -> ort::Result<()> { + static ONCE: once_cell::sync::OnceCell<()> = once_cell::sync::OnceCell::new(); + ONCE.get_or_try_init(|| ort::init().with_name(env!("CARGO_PKG_NAME")).commit())?; + Ok(()) +} + pub(crate) struct OnnxruntimeRunContext<'sess> { - sess: &'sess mut AssertSend>, - inputs: Vec>, + sess: &'sess ort::Session, + inputs: Vec>, } impl OnnxruntimeRunContext<'_> { fn push_input( &mut self, - input: Array, - ) { - self.inputs - .push(Box::new(onnxruntime::session::NdArray::new(input))); + input: Array< + impl IntoTensorElementType + Debug + Clone + 'static, + impl Dimension + 'static, + >, + ) -> anyhow::Result<()> { + let input = ort::Value::from_array(input)?.into(); + self.inputs.push(input); + Ok(()) } } -impl<'sess> From<&'sess mut AssertSend>> - for OnnxruntimeRunContext<'sess> -{ - fn from(sess: &'sess mut AssertSend>) -> Self { +impl<'sess> From<&'sess mut ort::Session> for OnnxruntimeRunContext<'sess> { + fn from(sess: &'sess mut ort::Session) -> Self { Self { sess, inputs: vec![], @@ -214,40 +221,7 @@ impl PushInputTensor for OnnxruntimeRunContext<'_> { [ push_int64 ] [ i64 ]; [ push_float32 ] [ f32 ]; )] - fn method(&mut self, tensor: Array) { - self.push_input(tensor); - } -} - -// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 -// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 -mod assert_send { - use std::ops::{Deref, DerefMut}; - - pub(crate) struct AssertSend(T); - - impl From> - for AssertSend> - { - fn from(session: onnxruntime::session::Session<'static>) -> Self { - Self(session) - } + fn method(&mut self, tensor: Array) -> anyhow::Result<()> { + 
self.push_input(tensor) } - - impl Deref for AssertSend { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl DerefMut for AssertSend { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - // SAFETY: `Session` is probably "send"able. - unsafe impl Send for AssertSend {} } diff --git a/crates/voicevox_core/src/infer/session_set.rs b/crates/voicevox_core/src/infer/session_set.rs index 56d570f98..cdd179680 100644 --- a/crates/voicevox_core/src/infer/session_set.rs +++ b/crates/voicevox_core/src/infer/session_set.rs @@ -94,9 +94,8 @@ impl InferenceSessionCell input: I, ) -> crate::Result<::Output> { let inner = &mut self.inner.lock().unwrap(); - let ctx = input.make_run_context::(inner); - R::run(ctx) - .and_then(TryInto::try_into) - .map_err(|e| ErrorRepr::InferenceFailed(e).into()) + (|| R::run(input.make_run_context::(inner)?)?.try_into())() + .map_err(ErrorRepr::InferenceFailed) + .map_err(Into::into) } } diff --git a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml index a6314f105..ad3a65fa7 100644 --- a/crates/voicevox_core_c_api/Cargo.toml +++ b/crates/voicevox_core_c_api/Cargo.toml @@ -13,6 +13,7 @@ harness = false name = "e2e" [features] +cuda = ["voicevox_core/cuda"] directml = ["voicevox_core/directml"] [dependencies] diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index a5da9b6d3..fbb0bf6bf 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -59,7 +59,13 @@ fn init_logger_once() { .with_env_filter(if env::var_os(EnvFilter::DEFAULT_ENV).is_some() { EnvFilter::from_default_env() } else { - "error,voicevox_core=info,voicevox_core_c_api=info,onnxruntime=info".into() + pub const ORT_LOGGING_LEVEL: &str = if cfg!(debug_assertions) { + "info" + } else { + "warn" + }; + format!("error,voicevox_core=info,voicevox_core_c_api=info,ort={ORT_LOGGING_LEVEL}") + .into() }) .with_timer(local_time as 
fn(&mut Writer<'_>) -> _) .with_ansi(ansi) diff --git a/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs b/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs index 1e4958eda..cfbec5c31 100644 --- a/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs +++ b/crates/voicevox_core_c_api/tests/e2e/assert_cdylib.rs @@ -46,11 +46,7 @@ pub(crate) fn exec() -> anyhow::Result<()> { // テスト対象が無いときに`cargo build`をスキップしたいが、判定部分がプライベート。 // そのためスキップするのはCLIオプションに`--ignored`か`--include-ignored`が無いときのみ if args.ignored || args.include_ignored { - let mut cmd = cmd!(env!("CARGO"), "build", "--release", "--lib"); - for (k, v) in C::BUILD_ENVS { - cmd = cmd.env(k, v); - } - cmd.run()?; + cmd!(env!("CARGO"), "build", "--release", "--lib").run()?; ensure!( C::cdylib_path().exists(), @@ -102,7 +98,6 @@ pub(crate) fn exec() -> anyhow::Result<()> { pub(crate) trait TestContext { const TARGET_DIR: &'static str; const CDYLIB_NAME: &'static str; - const BUILD_ENVS: &'static [(&'static str, &'static str)]; const RUNTIME_ENVS: &'static [(&'static str, &'static str)]; } diff --git a/crates/voicevox_core_c_api/tests/e2e/main.rs b/crates/voicevox_core_c_api/tests/e2e/main.rs index 91f5e06e9..43dc3a95e 100644 --- a/crates/voicevox_core_c_api/tests/e2e/main.rs +++ b/crates/voicevox_core_c_api/tests/e2e/main.rs @@ -24,16 +24,6 @@ fn main() -> anyhow::Result<()> { impl assert_cdylib::TestContext for TestContext { const TARGET_DIR: &'static str = "../../target"; const CDYLIB_NAME: &'static str = "voicevox_core"; - const BUILD_ENVS: &'static [(&'static str, &'static str)] = &[ - // 他の単体テストが動いているときにonnxruntime-sysの初回ビルドを行うと、Windows環境だと - // `$ORT_OUT_DIR`のハックが問題を起こす。そのためこのハック自体を無効化する - // - // featuresの差分を出さないように`cargo build`することができればonnxruntime-sysの - // ビルド自体がされないのだが、このバイナリから`cargo build`の状況を知るのは無理に近い - ("ORT_OUT_DIR", ""), - // DirectMLとCUDAは無効化 - ("ORT_USE_CUDA", "0"), - ]; const RUNTIME_ENVS: &'static [(&'static str, &'static str)] = &[("VV_MODELS_ROOT_DIR", VV_MODELS_ROOT_DIR)]; } diff --git 
a/crates/voicevox_core_java_api/Cargo.toml b/crates/voicevox_core_java_api/Cargo.toml index 887813685..06b2af618 100644 --- a/crates/voicevox_core_java_api/Cargo.toml +++ b/crates/voicevox_core_java_api/Cargo.toml @@ -8,6 +8,7 @@ publish.workspace = true crate-type = ["cdylib"] [features] +cuda = ["voicevox_core/cuda"] directml = ["voicevox_core/directml"] [dependencies] diff --git a/crates/voicevox_core_java_api/settings.gradle b/crates/voicevox_core_java_api/settings.gradle index 20a5e2c6a..75f5810ac 100644 --- a/crates/voicevox_core_java_api/settings.gradle +++ b/crates/voicevox_core_java_api/settings.gradle @@ -40,5 +40,5 @@ gradle.ext { gsonVersion = '2.10.1' jakartaValidationVersion = '3.0.2' jakartaAnnotationVersion = '2.1.1' - onnxruntimeVersion = '1.14.0' + onnxruntimeVersion = '1.17.3' } diff --git a/crates/voicevox_core_java_api/src/logger.rs b/crates/voicevox_core_java_api/src/logger.rs index 4800452ca..30545725e 100644 --- a/crates/voicevox_core_java_api/src/logger.rs +++ b/crates/voicevox_core_java_api/src/logger.rs @@ -10,10 +10,11 @@ extern "system" fn Java_jp_hiroshiba_voicevoxcore_Dll_00024LoggerInitializer_ini android_logger::Config::default() .with_tag("VoicevoxCore") .with_filter( - android_logger::FilterBuilder::new() - .parse("error,voicevox_core=info,voicevox_core_java_api=info,onnxruntime=error") - .build(), - ), + android_logger::FilterBuilder::new() + // FIXME: ortも`warn`は出すべき + .parse("error,voicevox_core=info,voicevox_core_java_api=info,ort=error") + .build(), + ), ); } else { // TODO: Android以外でのログ出力を良い感じにする。(System.Loggerを使う?) 
@@ -29,7 +30,8 @@ extern "system" fn Java_jp_hiroshiba_voicevoxcore_Dll_00024LoggerInitializer_ini .with_env_filter(if env::var_os(EnvFilter::DEFAULT_ENV).is_some() { EnvFilter::from_default_env() } else { - "error,voicevox_core=info,voicevox_core_c_api=info,onnxruntime=error".into() + // FIXME: `c_api`じゃないし、ortも`warn`は出すべき + "error,voicevox_core=info,voicevox_core_c_api=info,ort=error".into() }) .with_timer(local_time as fn(&mut Writer<'_>) -> _) .with_ansi(out().is_terminal() && env_allows_ansi()) diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index 72bc4d18a..d24a20ab1 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -223,22 +223,28 @@ pub(crate) fn derive_inference_input_signature( fn make_run_context<R: crate::infer::InferenceRuntime>( self, sess: &mut R::Session, - ) -> R::RunContext<'_> { + ) -> ::anyhow::Result<R::RunContext<'_>> { let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess); #( - __ArrayExt::push_to_ctx(self.#field_names, &mut ctx); + __ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?; )* - return ctx; + return ::std::result::Result::Ok(ctx); trait __ArrayExt { - fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor); + fn push_to_ctx( + self, + ctx: &mut impl crate::infer::PushInputTensor, + ) -> ::anyhow::Result<()>; } impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension> __ArrayExt for ::ndarray::Array<A, D> { - fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) { - A::push_tensor_to_ctx(self, ctx); + fn push_to_ctx( + self, + ctx: &mut impl crate::infer::PushInputTensor, + ) -> ::anyhow::Result<()> { + A::push_tensor_to_ctx(self, ctx) } } } diff --git a/crates/voicevox_core_python_api/Cargo.toml b/crates/voicevox_core_python_api/Cargo.toml index be3ecbf27..5ccd1dc41 100644 --- a/crates/voicevox_core_python_api/Cargo.toml +++ b/crates/voicevox_core_python_api/Cargo.toml @@ -8,6 +8,7 @@ publish.workspace = true crate-type = ["cdylib"] [features] +cuda =
["voicevox_core/cuda"] directml = ["voicevox_core/directml"] [dependencies]