diff --git a/.gitignore b/.gitignore
index 26620d1bd5214..739ec17ca2fce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,3 +57,7 @@ onnxruntime/python/version_info.py
 # clangd
 .cache/
 compile_commands.json
+# Rust specific
+rust/**/target
+rust/**/Cargo.lock
+rust/onnxruntime/synset.txt
diff --git a/rust/BUILD.md b/rust/BUILD.md
new file mode 100644
index 0000000000000..68500c7fc624a
--- /dev/null
+++ b/rust/BUILD.md
@@ -0,0 +1,48 @@
+# Building and testing the Rust bindings
+
+These instructions require cargo and rustc.
+To get these, follow the instructions at [https://rustup.rs](https://rustup.rs).
+The instructions compile onnxruntime along with the bindings,
+so they also require `cmake`, a Python 3 interpreter, clang (needed to parse the C headers to generate the Rust bindings),
+and the platform compiler to compile onnxruntime.
+
+## Local setup of onnxruntime repo
+
+```sh
+git clone https://github.com/microsoft/onnxruntime
+cd onnxruntime
+git submodule update --init --recursive
+```
+
+## cargo build both crates
+
+From the root of the onnxruntime repo:
+
+```sh
+CARGO_TARGET_DIR=build/rust cargo build --manifest-path rust/Cargo.toml
+```
+
+The `CARGO_TARGET_DIR` environment variable puts the build artifacts in `onnxruntime/build/rust`
+instead of `onnxruntime/rust/target`.
+
+## cargo test both crates
+
+```sh
+CARGO_TARGET_DIR=build/rust cargo test --manifest-path rust/Cargo.toml --features model-fetching
+```
+
+### cargo test both crates while specifying the absolute path to the ONNX Runtime shared library
+
+```sh
+RUST_ONNXRUNTIME_LIBRARY_PATH=<absolute path to the onnxruntime shared library> CARGO_TARGET_DIR=build/rust cargo test --manifest-path rust/Cargo.toml --features model-fetching
+```
+
+## cargo test with sanitizer support
+
+**This requires a nightly Rust compiler and one of the platforms listed in [Rust sanitizer support](https://doc.rust-lang.org/beta/unstable-book/compiler-flags/sanitizer.html).**
+
+Here `$SAN` is one of `address`, `thread`, `memory` or `leak`:
+
+```sh
+RUSTFLAGS="-Zsanitizer=$SAN" CARGO_TARGET_DIR=build/rust cargo test --manifest-path rust/Cargo.toml --features model-fetching --target <target triple> -Z build-std -- --test-threads=1
+```
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
new file mode 100644
index 0000000000000..7c33647c5d3da
--- /dev/null
+++ b/rust/Cargo.toml
@@ -0,0 +1,5 @@
+[workspace]
+members = [
+    "onnxruntime-sys",
+    "onnxruntime",
+]
diff --git a/rust/LICENSE-APACHE b/rust/LICENSE-APACHE
new file mode 100644
index 0000000000000..e0284d8a8d512
--- /dev/null
+++ b/rust/LICENSE-APACHE
@@ -0,0 +1,201 @@
+                              Apache License
+                        Version 2.0, January 2004
+                     http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+   "License" shall mean the terms and conditions for use, reproduction,
+   and distribution as defined by Sections 1 through 9 of this document.
+
+   "Licensor" shall mean the copyright owner or entity authorized by
+   the copyright owner that is granting the License.
+
+   "Legal Entity" shall mean the union of the acting entity and all
+   other entities that control, are controlled by, or are under common
+   control with that entity. For the purposes of this definition,
+   "control" means (i) the power, direct or indirect, to cause the
+   direction or management of such entity, whether by contract or
+   otherwise, or (ii) ownership of fifty percent (50%) or more of the
+   outstanding shares, or (iii) beneficial ownership of such entity.
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2020 Nicolas Bigaouette + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/rust/LICENSE-MIT b/rust/LICENSE-MIT
new file mode 100644
index 0000000000000..2b6d07c1daf81
--- /dev/null
+++ b/rust/LICENSE-MIT
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Nicolas Bigaouette
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/rust/README.md b/rust/README.md
new file mode 100644
index 0000000000000..14b9e8cd632b4
--- /dev/null
+++ b/rust/README.md
@@ -0,0 +1,196 @@
+# ONNX Runtime
+
+These are Rust bindings to
+[Microsoft's ONNX Runtime](https://github.com/microsoft/onnxruntime).
+
+This project consists of two crates:
+
+* [`onnxruntime-sys`](onnxruntime-sys): Low-level binding to the C API;
+* [`onnxruntime`](onnxruntime): High-level and safe API.
+
+The `build.rs` script supports downloading pre-built versions of the Microsoft ONNX Runtime,
+which are available for the following targets:
+
+CPU:
+
+* Linux x86_64
+* macOS x86_64
+* macOS aarch64
+* Windows i686
+* Windows x86_64
+
+GPU:
+
+* Linux x86_64
+* Windows x86_64
+
+---
+
+**WARNING**:
+
+* This is an experiment and work in progress; it is _not_ complete/working/safe. Help welcome!
+* Basic inference works; see [`onnxruntime/examples/sample.rs`](onnxruntime/examples/sample.rs) or [`onnxruntime/tests/integration_tests.rs`](onnxruntime/tests/integration_tests.rs).
+* ONNX Runtime has many options to control the inference process, but those options are not yet exposed.
+
+---
+
+## Setup
+
+Three different strategies to obtain ONNX Runtime are supported by the `build.rs` script:
+
+1. Download a pre-built binary from upstream;
+2. Point to a local version already installed;
+3. Compile from source.
+
+To select which strategy to use, set the `ORT_RUST_STRATEGY` environment variable to:
+
+1. `download`: Download a prebuilt onnxruntime;
+2. `system`: Use a locally installed version (use the `ORT_RUST_LIB_LOCATION` environment variable to point to the install path);
+3. `compile`: Compile the library from source. This is the default.
+
+The `download` strategy supports downloading a version of ONNX Runtime that supports CUDA. To use this, set the
+environment variable `ORT_RUST_USE_CUDA=1` (supported only on Linux and Windows).
+
+### Note on 'ORT_RUST_STRATEGY=system'
+
+When using `ORT_RUST_STRATEGY=system`, executing a built crate binary (for example the tests) might fail, at least on macOS,
+if the library is not installed in a system path.
+An error similar to the following happens:
+
+```text
+dyld: Library not loaded: @rpath/libonnxruntime.1.7.1.dylib
+  Referenced from: onnxruntime-rs.git/target/debug/deps/onnxruntime_sys-22eb0e3e89a0278c
+  Reason: image not found
+```
+
+To fix this, one can either:
+
+* Set the `LD_LIBRARY_PATH` (Linux) or `DYLD_LIBRARY_PATH` (macOS) environment variable to point to the path where the library can be found.
+* Adapt the `.cargo/config` file to contain a linker flag to provide the **full** path:
+
+  ```toml
+  [target.aarch64-apple-darwin]
+  rustflags = ["-C", "link-args=-Wl,-rpath,/full/path/to/onnxruntime/lib"]
+  ```
+
+See [rust-lang/cargo #5077](https://github.com/rust-lang/cargo/issues/5077) for more information.
+
+## Example
+
+The C++ example that uses the C API
+([`C_Api_Sample.cpp`](https://github.com/microsoft/onnxruntime/blob/v1.3.1/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp))
+was ported to both the low level crate (`onnxruntime-sys`) and the high level one (`onnxruntime`).
+
+### onnxruntime-sys
+
+To run this example ([`onnxruntime-sys/examples/c_api_sample.rs`](onnxruntime-sys/examples/c_api_sample.rs)):
+
+```sh
+# Download the model (SqueezeNet 1.0, ONNX version: 1.3, Opset version: 8)
+❯ curl -LO "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
+❯ cargo run --example c_api_sample
+[...]
+    Finished dev [unoptimized + debuginfo] target(s) in 1.88s
+     Running `target/debug/examples/c_api_sample`
+Using Onnxruntime C API
+2020-08-09 09:37:41.554922 [I:onnxruntime:, inference_session.cc:174 ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
+2020-08-09 09:37:41.556650 [I:onnxruntime:, inference_session.cc:830 Initialize] Initializing session.
+2020-08-09 09:37:41.556665 [I:onnxruntime:, inference_session.cc:848 Initialize] Adding default CPU execution provider.
+2020-08-09 09:37:41.556678 [I:onnxruntime:test, bfc_arena.cc:15 BFCArena] Creating BFCArena for Cpu
+2020-08-09 09:37:41.556687 [V:onnxruntime:test, bfc_arena.cc:32 BFCArena] Creating 21 bins of max chunk size 256 to 268435456
+2020-08-09 09:37:41.558313 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:37:41.559327 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:37:41.559476 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:37:41.559607 [V:onnxruntime:, inference_session.cc:671 TransformGraph] Node placements
+2020-08-09 09:37:41.559615 [V:onnxruntime:, inference_session.cc:673 TransformGraph] All nodes have been placed on [CPUExecutionProvider].
+2020-08-09 09:37:41.559639 [I:onnxruntime:, session_state.cc:25 SetGraph] SaveMLValueNameIndexMapping
+2020-08-09 09:37:41.559787 [I:onnxruntime:, session_state.cc:70 SetGraph] Done saving OrtValue mappings.
+2020-08-09 09:37:41.560252 [I:onnxruntime:, session_state_initializer.cc:178 SaveInitializedTensors] Saving initialized tensors.
+2020-08-09 09:37:41.563467 [I:onnxruntime:, session_state_initializer.cc:223 SaveInitializedTensors] Done saving initialized tensors
+2020-08-09 09:37:41.563979 [I:onnxruntime:, inference_session.cc:919 Initialize] Session successfully initialized.
+Number of inputs = 1
+Input 0 : name=data_0
+Input 0 : type=1
+Input 0 : num_dims=4
+Input 0 : dim 0=1
+Input 0 : dim 1=3
+Input 0 : dim 2=224
+Input 0 : dim 3=224
+2020-08-09 09:37:41.573127 [I:onnxruntime:, sequential_executor.cc:145 Execute] Begin execution
+2020-08-09 09:37:41.573183 [I:onnxruntime:test, bfc_arena.cc:259 AllocateRawInternal] Extending BFCArena for Cpu. bin_num:13 rounded_bytes:3154176
+2020-08-09 09:37:41.573197 [I:onnxruntime:test, bfc_arena.cc:143 Extend] Extended allocation by 4194304 bytes.
+2020-08-09 09:37:41.573203 [I:onnxruntime:test, bfc_arena.cc:147 Extend] Total allocated bytes: 9137152
+2020-08-09 09:37:41.573212 [I:onnxruntime:test, bfc_arena.cc:150 Extend] Allocated memory at 0x7fb7d6cb7000 to 0x7fb7d70b7000
+2020-08-09 09:37:41.573248 [I:onnxruntime:test, bfc_arena.cc:259 AllocateRawInternal] Extending BFCArena for Cpu. bin_num:8 rounded_bytes:65536
+2020-08-09 09:37:41.573256 [I:onnxruntime:test, bfc_arena.cc:143 Extend] Extended allocation by 4194304 bytes.
+2020-08-09 09:37:41.573262 [I:onnxruntime:test, bfc_arena.cc:147 Extend] Total allocated bytes: 13331456
+2020-08-09 09:37:41.573268 [I:onnxruntime:test, bfc_arena.cc:150 Extend] Allocated memory at 0x7fb7d70b7000 to 0x7fb7d74b7000
+Score for class [0] = 0.000045440644
+Score for class [1] = 0.0038458651
+Score for class [2] = 0.00012494653
+Score for class [3] = 0.0011804523
+Score for class [4] = 0.0013169361
+Done!
+```
+
+### onnxruntime
+
+To run this example ([`onnxruntime/examples/sample.rs`](onnxruntime/examples/sample.rs)):
+
+```sh
+# Download the model (SqueezeNet 1.0, ONNX version: 1.3, Opset version: 8)
+❯ curl -LO "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
+❯ cargo run --example sample
+[...]
+    Finished dev [unoptimized + debuginfo] target(s) in 13.62s
+     Running `target/debug/examples/sample`
+Uninitialized environment found, initializing it with name "test".
+2020-08-09 09:34:37.395577 [I:onnxruntime:, inference_session.cc:174 ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
+2020-08-09 09:34:37.399253 [I:onnxruntime:, inference_session.cc:830 Initialize] Initializing session.
+2020-08-09 09:34:37.399284 [I:onnxruntime:, inference_session.cc:848 Initialize] Adding default CPU execution provider.
+2020-08-09 09:34:37.399313 [I:onnxruntime:test, bfc_arena.cc:15 BFCArena] Creating BFCArena for Cpu
+2020-08-09 09:34:37.399335 [V:onnxruntime:test, bfc_arena.cc:32 BFCArena] Creating 21 bins of max chunk size 256 to 268435456
+2020-08-09 09:34:37.410516 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:34:37.417478 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:34:37.420131 [I:onnxruntime:, reshape_fusion.cc:37 ApplyImpl] Total fused reshape node count: 0
+2020-08-09 09:34:37.422623 [V:onnxruntime:, inference_session.cc:671 TransformGraph] Node placements
+2020-08-09 09:34:37.428863 [V:onnxruntime:, inference_session.cc:673 TransformGraph] All nodes have been placed on [CPUExecutionProvider].
+2020-08-09 09:34:37.428954 [I:onnxruntime:, session_state.cc:25 SetGraph] SaveMLValueNameIndexMapping
+2020-08-09 09:34:37.429079 [I:onnxruntime:, session_state.cc:70 SetGraph] Done saving OrtValue mappings.
+2020-08-09 09:34:37.429925 [I:onnxruntime:, session_state_initializer.cc:178 SaveInitializedTensors] Saving initialized tensors.
+2020-08-09 09:34:37.436300 [I:onnxruntime:, session_state_initializer.cc:223 SaveInitializedTensors] Done saving initialized tensors
+2020-08-09 09:34:37.437255 [I:onnxruntime:, inference_session.cc:919 Initialize] Session successfully initialized.
+Dropping the session options.
+2020-08-09 09:34:37.448956 [I:onnxruntime:, sequential_executor.cc:145 Execute] Begin execution
+2020-08-09 09:34:37.449041 [I:onnxruntime:test, bfc_arena.cc:259 AllocateRawInternal] Extending BFCArena for Cpu. bin_num:13 rounded_bytes:3154176
+2020-08-09 09:34:37.449072 [I:onnxruntime:test, bfc_arena.cc:143 Extend] Extended allocation by 4194304 bytes.
+2020-08-09 09:34:37.449087 [I:onnxruntime:test, bfc_arena.cc:147 Extend] Total allocated bytes: 9137152
+2020-08-09 09:34:37.449104 [I:onnxruntime:test, bfc_arena.cc:150 Extend] Allocated memory at 0x7fb3b9585000 to 0x7fb3b9985000
+2020-08-09 09:34:37.449176 [I:onnxruntime:test, bfc_arena.cc:259 AllocateRawInternal] Extending BFCArena for Cpu. bin_num:8 rounded_bytes:65536
+2020-08-09 09:34:37.449196 [I:onnxruntime:test, bfc_arena.cc:143 Extend] Extended allocation by 4194304 bytes.
+2020-08-09 09:34:37.449209 [I:onnxruntime:test, bfc_arena.cc:147 Extend] Total allocated bytes: 13331456
+2020-08-09 09:34:37.449222 [I:onnxruntime:test, bfc_arena.cc:150 Extend] Allocated memory at 0x7fb3b9985000 to 0x7fb3b9d85000
+Dropping Tensor.
+Score for class [0] = 0.000045440578
+Score for class [1] = 0.0038458686
+Score for class [2] = 0.0001249467
+Score for class [3] = 0.0011804511
+Score for class [4] = 0.00131694
+Dropping TensorFromOrt.
+Dropping the session.
+Dropping the memory information.
+Dropping the environment.
+```
+
+See also the integration tests ([`onnxruntime/tests/integration_tests.rs`](onnxruntime/tests/integration_tests.rs)),
+which perform a simple model download and inference and validate the results.
+
+## License
+
+The Rust bindings are licensed under either of
+
+* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
+  http://www.apache.org/licenses/LICENSE-2.0)
+* MIT license ([LICENSE-MIT](LICENSE-MIT) or
+  http://opensource.org/licenses/MIT)
+
+at your option.
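Condensing the high-level example above: the whole C API ceremony reduces to an environment, a session builder, and a `run` call. A minimal sketch of that pattern, using the same API as `sample.rs` further below (the model file name is illustrative):

```rust
use onnxruntime::{environment::Environment, ndarray::Array, GraphOptimizationLevel, LoggingLevel};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let environment = Environment::builder()
        .with_name("example")
        .with_log_level(LoggingLevel::Warning)
        .build()?;

    let session = environment
        .new_session_builder()?
        .with_graph_optimization_level(GraphOptimizationLevel::Basic)?
        .with_model_from_file("squeezenet1.0-8.onnx")?;

    // One f32 input tensor of the shape SqueezeNet expects.
    let input = Array::linspace(0.0_f32, 1.0, 3 * 224 * 224)
        .into_shape(vec![1, 3, 224, 224])?;
    let outputs = session.run(vec![input.into()])?;
    println!("{:?}", outputs[0].float_array());
    Ok(())
}
```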
diff --git a/rust/onnxruntime-sys/Cargo.toml b/rust/onnxruntime-sys/Cargo.toml
new file mode 100644
index 0000000000000..4806e6ca2953c
--- /dev/null
+++ b/rust/onnxruntime-sys/Cargo.toml
@@ -0,0 +1,35 @@
+[package]
+authors = ["Nicolas Bigaouette <nbigaouette@gmail.com>"]
+edition = "2018"
+name = "onnxruntime-sys"
+version = "0.0.14"
+
+links = "onnxruntime"
+
+description = "Unsafe wrapper around Microsoft's ONNX Runtime"
+documentation = "https://docs.rs/onnxruntime-sys"
+homepage = "https://github.com/microsoft/onnxruntime"
+license = "MIT OR Apache-2.0"
+readme = "../README.md"
+repository = "https://github.com/microsoft/onnxruntime"
+
+categories = ["science"]
+keywords = ["neuralnetworks", "onnx", "bindings"]
+
+[dependencies]
+libloading = "0.7"
+
+[build-dependencies]
+bindgen = "0.63"
+cmake = "0.1"
+
+# Used on unix
+flate2 = "1.0"
+tar = "0.4"
+ureq = "2.1"
+
+# Used on Windows
+zip = "0.6"
+
+[features]
+default = []
diff --git a/rust/onnxruntime-sys/build.rs b/rust/onnxruntime-sys/build.rs
new file mode 100644
index 0000000000000..82d1e4278015c
--- /dev/null
+++ b/rust/onnxruntime-sys/build.rs
@@ -0,0 +1,429 @@
+#![allow(dead_code)]
+
+use std::{
+    borrow::Cow,
+    env, fs,
+    io::{self, Read, Write},
+    path::{Path, PathBuf},
+    str::FromStr,
+};
+
+/// ONNX Runtime version
+///
+/// WARNING: If the version is changed, bindings for all platforms will have to be re-generated.
+/// The bindings are generated by this build script, so rebuilding the crate is enough:
+///     cargo build --package onnxruntime-sys
+const ORT_VERSION: &str = include_str!("../../VERSION_NUMBER");

+/// Base URL from which to download pre-built releases.
+const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
+
+/// Environment variable selecting which strategy to use for finding the library.
+/// Possibilities:
+/// * "download": Download a pre-built library.
+/// * "system": Use an installed library. Use `ORT_RUST_LIB_LOCATION` to point to the proper location.
+/// * "compile": Download the source and compile. This is the default if `ORT_RUST_STRATEGY` is not set.
+const ORT_RUST_ENV_STRATEGY: &str = "ORT_RUST_STRATEGY";
+
+/// Name of environment variable that, if present, contains the location of a pre-built library.
+/// Only used if `ORT_RUST_STRATEGY=system`.
+const ORT_RUST_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_RUST_LIB_LOCATION";
+/// Name of environment variable that, if present, controls whether to use CUDA or not.
+const ORT_RUST_ENV_GPU: &str = "ORT_RUST_USE_CUDA";
+
+/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
+const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
+
+fn main() {
+    let libort_install_dir = prepare_libort_dir();
+
+    let include_dir = libort_install_dir.join("include");
+    let lib_dir = libort_install_dir.join("lib");
+
+    println!("Include directory: {:?}", include_dir);
+    println!("Lib directory: {:?}", lib_dir);
+
+    // Tell cargo to tell rustc to link onnxruntime shared library.
+    println!("cargo:rustc-link-lib=onnxruntime");
+    println!("cargo:rustc-link-search=native={}", lib_dir.display());
+
+    println!("cargo:rerun-if-env-changed={}", ORT_RUST_ENV_STRATEGY);
+    println!("cargo:rerun-if-env-changed={}", ORT_RUST_ENV_GPU);
+    println!(
+        "cargo:rerun-if-env-changed={}",
+        ORT_RUST_ENV_SYSTEM_LIB_LOCATION
+    );
+
+    generate_bindings(&include_dir);
+}
+
+fn generate_bindings(include_dir: &Path) {
+    let clang_args = &[
+        format!("-I{}", include_dir.display()),
+        format!(
+            "-I{}",
+            include_dir
+                .join("onnxruntime")
+                .join("core")
+                .join("session")
+                .display()
+        ),
+    ];
+
+    let path = include_dir
+        .join("onnxruntime")
+        .join("core")
+        .join("session")
+        .join("onnxruntime_c_api.h");
+
+    // The bindgen::Builder is the main entry point
+    // to bindgen, and lets you build up options for
+    // the resulting bindings.
+    let bindings = bindgen::Builder::default()
+        // The input header we would like to generate
+        // bindings for.
+        .header(path.to_string_lossy().to_string())
+        // The current working directory is 'onnxruntime-sys'
+        .clang_args(clang_args)
+        // Tell cargo to invalidate the built crate whenever any of the
+        // included header files changed.
+        .parse_callbacks(Box::new(bindgen::CargoCallbacks))
+        .dynamic_library_name("onnxruntime")
+        .allowlist_type("Ort.*")
+        .allowlist_type("Onnx.*")
+        .allowlist_type("ONNX.*")
+        .allowlist_function("Ort.*")
+        .allowlist_var("ORT.*")
+        // Set `size_t` to be translated to `usize` for win32 compatibility.
+        .size_t_is_usize(true)
+        // Format using rustfmt
+        .rustfmt_bindings(true)
+        .rustified_enum(".*")
+        // Finish the builder and generate the bindings.
+        .generate()
+        // Unwrap the Result and panic on failure.
+ .expect("Unable to generate bindings"); + + let generated_file = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs"); + println!("cargo:rerun-if-changed={:?}", generated_file); + bindings + .write_to_file(&generated_file) + .expect("Couldn't write bindings!"); +} + +fn download

(source_url: &str, target_file: P) +where + P: AsRef, +{ + let resp = ureq::get(source_url) + .timeout(std::time::Duration::from_secs(300)) + .call() + .unwrap_or_else(|err| panic!("ERROR: Failed to download {}: {:?}", source_url, err)); + + let len = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .unwrap(); + let mut reader = resp.into_reader(); + // FIXME: Save directly to the file + let mut buffer = vec![]; + let read_len = reader.read_to_end(&mut buffer).unwrap(); + assert_eq!(buffer.len(), len); + assert_eq!(buffer.len(), read_len); + + let f = fs::File::create(&target_file).unwrap(); + let mut writer = io::BufWriter::new(f); + writer.write_all(&buffer).unwrap(); +} + +fn extract_archive(filename: &Path, output: &Path) { + match filename.extension().map(std::ffi::OsStr::to_str) { + Some(Some("zip")) => extract_zip(filename, output), + Some(Some("tgz")) => extract_tgz(filename, output), + _ => unimplemented!(), + } +} + +fn extract_tgz(filename: &Path, output: &Path) { + let file = fs::File::open(&filename).unwrap(); + let buf = io::BufReader::new(file); + let tar = flate2::read::GzDecoder::new(buf); + let mut archive = tar::Archive::new(tar); + archive.unpack(output).unwrap(); +} + +fn extract_zip(filename: &Path, outpath: &Path) { + let file = fs::File::open(&filename).unwrap(); + let buf = io::BufReader::new(file); + let mut archive = zip::ZipArchive::new(buf).unwrap(); + for i in 0..archive.len() { + let mut file = archive.by_index(i).unwrap(); + #[allow(deprecated)] + let outpath = outpath.join(file.sanitized_name()); + if !file.name().ends_with('/') { + println!( + "File {} extracted to \"{}\" ({} bytes)", + i, + outpath.as_path().display(), + file.size() + ); + if let Some(p) = outpath.parent() { + if !p.exists() { + fs::create_dir_all(&p).unwrap(); + } + } + let mut outfile = fs::File::create(&outpath).unwrap(); + io::copy(&mut file, &mut outfile).unwrap(); + } + } +} + +trait OnnxPrebuiltArchive { + fn as_onnx_str(&self) -> Cow; +} + +#[derive(Debug)] +enum Architecture { + X86, + X86_64, + Arm, + Arm64, +} + +impl FromStr for Architecture { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "x86" => Ok(Architecture::X86), + "x86_64" => Ok(Architecture::X86_64), + "arm" => Ok(Architecture::Arm), + "aarch64" => Ok(Architecture::Arm64), + _ => Err(format!("Unsupported architecture: {}", s)), + } + } +} + +impl OnnxPrebuiltArchive for Architecture { + fn as_onnx_str(&self) -> Cow { + match self { + Architecture::X86 => Cow::from("x86"), + Architecture::X86_64 => Cow::from("x64"), + Architecture::Arm => Cow::from("arm"), + Architecture::Arm64 => Cow::from("arm64"), + } + } +} + +#[derive(Debug)] +#[allow(clippy::enum_variant_names)] +enum Os { + Windows, + Linux, + MacOs, +} + +impl Os { + fn archive_extension(&self) -> &'static str { + match self { + Os::Windows => "zip", + Os::Linux => "tgz", + Os::MacOs => "tgz", + } + } +} + +impl FromStr for Os { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "windows" => Ok(Os::Windows), + "macos" => Ok(Os::MacOs), + "linux" => Ok(Os::Linux), + _ => Err(format!("Unsupported os: {}", s)), + } + } +} + +impl OnnxPrebuiltArchive for Os { + fn as_onnx_str(&self) -> Cow { + match self { + Os::Windows => Cow::from("win"), + Os::Linux => Cow::from("linux"), + Os::MacOs => Cow::from("osx"), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +enum Accelerator { + Cpu, + Cuda, +} + +impl FromStr for Accelerator { + type Err = String; + + fn 
+
+impl FromStr for Accelerator {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "1" | "yes" | "true" | "on" => Ok(Accelerator::Cuda),
+            _ => Ok(Accelerator::Cpu),
+        }
+    }
+}
+
+impl OnnxPrebuiltArchive for Accelerator {
+    fn as_onnx_str(&self) -> Cow<str> {
+        match self {
+            Accelerator::Cpu => Cow::from(""),
+            Accelerator::Cuda => Cow::from("gpu"),
+        }
+    }
+}
+
+#[derive(Debug)]
+struct Triplet {
+    os: Os,
+    arch: Architecture,
+    accelerator: Accelerator,
+}
+
+impl OnnxPrebuiltArchive for Triplet {
+    fn as_onnx_str(&self) -> Cow<str> {
+        match (&self.os, &self.arch, &self.accelerator) {
+            // onnxruntime-win-x86-1.11.1.zip
+            // onnxruntime-win-x64-1.11.1.zip
+            // onnxruntime-win-arm-1.11.1.zip
+            // onnxruntime-win-arm64-1.11.1.zip
+            // onnxruntime-linux-x64-1.11.1.tgz
+            // onnxruntime-osx-x86_64-1.11.1.tgz
+            // onnxruntime-osx-arm64-1.11.1.tgz
+            (
+                Os::Windows,
+                Architecture::X86 | Architecture::X86_64 | Architecture::Arm | Architecture::Arm64,
+                Accelerator::Cpu,
+            )
+            | (Os::MacOs, Architecture::Arm64, Accelerator::Cpu)
+            | (Os::Linux, Architecture::X86_64, Accelerator::Cpu) => Cow::from(format!(
+                "{}-{}",
+                self.os.as_onnx_str(),
+                self.arch.as_onnx_str()
+            )),
+            (Os::MacOs, Architecture::X86_64, Accelerator::Cpu) => Cow::from(format!(
+                "{}-x86_{}",
+                self.os.as_onnx_str(),
+                self.arch.as_onnx_str().trim_start_matches('x')
+            )),
+            // onnxruntime-win-x64-gpu-1.11.1.zip
+            // onnxruntime-linux-x64-gpu-1.11.1.tgz
+            (Os::Linux | Os::Windows, Architecture::X86_64, Accelerator::Cuda) => {
+                Cow::from(format!(
+                    "{}-{}-{}",
+                    self.os.as_onnx_str(),
+                    self.arch.as_onnx_str(),
+                    self.accelerator.as_onnx_str(),
+                ))
+            }
+            _ => {
+                panic!(
+                    "Unsupported prebuilt triplet: {:?}, {:?}, {:?}. Please use {}=system and {}=/path/to/onnxruntime",
+                    self.os, self.arch, self.accelerator, ORT_RUST_ENV_STRATEGY, ORT_RUST_ENV_SYSTEM_LIB_LOCATION
+                );
+            }
+        }
+    }
+}
+
+fn prebuilt_archive_url() -> (PathBuf, String) {
+    let triplet = Triplet {
+        os: env::var("CARGO_CFG_TARGET_OS")
+            .expect("Unable to get TARGET_OS")
+            .parse()
+            .unwrap(),
+        arch: env::var("CARGO_CFG_TARGET_ARCH")
+            .expect("Unable to get TARGET_ARCH")
+            .parse()
+            .unwrap(),
+        accelerator: env::var(ORT_RUST_ENV_GPU)
+            .unwrap_or_default()
+            .parse()
+            .unwrap(),
+    };
+
+    let prebuilt_archive = format!(
+        "onnxruntime-{}-{}.{}",
+        triplet.as_onnx_str(),
+        ORT_VERSION,
+        triplet.os.archive_extension()
+    );
+    let prebuilt_url = format!(
+        "{}/v{}/{}",
+        ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive
+    );
+
+    (PathBuf::from(prebuilt_archive), prebuilt_url)
+}
+
+fn prepare_libort_dir_prebuilt() -> PathBuf {
+    let (prebuilt_archive, prebuilt_url) = prebuilt_archive_url();
+
+    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
+    let extract_dir = out_dir.join(ORT_PREBUILT_EXTRACT_DIR);
+    let downloaded_file = out_dir.join(&prebuilt_archive);
+
+    println!("cargo:rerun-if-changed={}", downloaded_file.display());
+
+    if !downloaded_file.exists() {
+        println!("Creating directory {:?}", out_dir);
+        fs::create_dir_all(&out_dir).unwrap();
+
+        println!(
+            "Downloading {} into {}",
+            prebuilt_url,
+            downloaded_file.display()
+        );
+        download(&prebuilt_url, &downloaded_file);
+    }
+
+    if !extract_dir.exists() {
+        println!("Extracting to {}...", extract_dir.display());
+        extract_archive(&downloaded_file, &extract_dir);
+    }
+
+    extract_dir.join(prebuilt_archive.file_stem().unwrap())
+}
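+
+// The FIXME in `download` above notes that the response is buffered fully in
+// memory before being written out. A possible streaming alternative is
+// sketched below (not called anywhere): `io::copy` writes the body to disk as
+// it is read, at the cost of dropping the Content-Length consistency checks.
+#[allow(dead_code)]
+fn download_streaming<P>(source_url: &str, target_file: P)
+where
+    P: AsRef<Path>,
+{
+    let resp = ureq::get(source_url)
+        .timeout(std::time::Duration::from_secs(300))
+        .call()
+        .unwrap_or_else(|err| panic!("ERROR: Failed to download {}: {:?}", source_url, err));
+
+    let mut reader = resp.into_reader();
+    let file = fs::File::create(&target_file).unwrap();
+    let mut writer = io::BufWriter::new(file);
+    io::copy(&mut reader, &mut writer).unwrap();
+}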
+ ); + match strategy.as_ref().map(String::as_str) { + Ok("download") => prepare_libort_dir_prebuilt(), + Ok("system") => PathBuf::from(match env::var(ORT_RUST_ENV_SYSTEM_LIB_LOCATION) { + Ok(p) => p, + Err(e) => { + panic!( + "Could not get value of environment variable {:?}: {:?}", + ORT_RUST_ENV_SYSTEM_LIB_LOCATION, e + ); + } + }), + Ok("compile") | Err(_) => prepare_libort_dir_compiled(), + _ => panic!("Unknown value for {:?}", ORT_RUST_ENV_STRATEGY), + } +} + +fn prepare_libort_dir_compiled() -> PathBuf { + let mut config = cmake::Config::new("../../cmake"); + + config.define("onnxruntime_BUILD_SHARED_LIB", "ON"); + + if env::var(ORT_RUST_ENV_GPU).unwrap_or_default().parse() == Ok(Accelerator::Cuda) { + config.define("onnxruntime_USE_CUDA", "ON"); + } + + config.build() +} diff --git a/rust/onnxruntime-sys/examples/c_api_sample.rs b/rust/onnxruntime-sys/examples/c_api_sample.rs new file mode 100644 index 0000000000000..499f1548de396 --- /dev/null +++ b/rust/onnxruntime-sys/examples/c_api_sample.rs @@ -0,0 +1,395 @@ +#![allow(non_snake_case)] + +use std::env::args; +#[cfg(not(target_family = "windows"))] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; + +use onnxruntime_sys::{ + onnxruntime, GraphOptimizationLevel, ONNXTensorElementDataType, OrtAllocator, OrtAllocatorType, + OrtApi, OrtEnv, OrtLoggingLevel, OrtMemType, OrtMemoryInfo, OrtRunOptions, OrtSession, + OrtSessionOptions, OrtStatus, OrtTensorTypeAndShapeInfo, OrtTypeInfo, OrtValue, + ORT_API_VERSION, +}; + +// https://github.com/microsoft/onnxruntime/blob/v1.4.0/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp + +fn main() { + let onnxruntime_path = args() + .nth(1) + .expect("This example expects a path to the ONNXRuntime shared library"); + + let (_, g_ort) = unsafe { + let ort = onnxruntime::new(onnxruntime_path); + + let ort = ort.expect("Error initializing onnxruntime"); + let g_ort = ort.OrtGetApiBase().as_ref().unwrap().GetApi.unwrap()(ORT_API_VERSION); + + (ort, g_ort) + }; + assert_ne!(g_ort, std::ptr::null_mut()); + + //************************************************************************* + // initialize enviroment...one enviroment per process + // enviroment maintains thread pools and other state info + let mut env_ptr: *mut OrtEnv = std::ptr::null_mut(); + let env_name = std::ffi::CString::new("test").unwrap(); + let status = unsafe { + g_ort.as_ref().unwrap().CreateEnv.unwrap()( + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + env_name.as_ptr(), + &mut env_ptr, + ) + }; + CheckStatus(g_ort, status).unwrap(); + assert_ne!(env_ptr, std::ptr::null_mut()); + + // initialize session options if needed + let mut session_options_ptr: *mut OrtSessionOptions = std::ptr::null_mut(); + let status = + unsafe { g_ort.as_ref().unwrap().CreateSessionOptions.unwrap()(&mut session_options_ptr) }; + CheckStatus(g_ort, status).unwrap(); + unsafe { g_ort.as_ref().unwrap().SetIntraOpNumThreads.unwrap()(session_options_ptr, 1) }; + assert_ne!(session_options_ptr, std::ptr::null_mut()); + + // Sets graph optimization level + unsafe { + g_ort + .as_ref() + .unwrap() + .SetSessionGraphOptimizationLevel + .unwrap()( + session_options_ptr, + GraphOptimizationLevel::ORT_ENABLE_BASIC, + ) + }; + + // Optionally add more execution providers via session_options + // E.g. 
for CUDA include cuda_provider_factory.h and uncomment the following line:
+    // OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
+
+    //*************************************************************************
+    // create session and load model into memory
+    // NOTE: Original C version loaded SqueezeNet 1.0 (ONNX version: 1.3, Opset version: 8,
+    // https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx)
+    // Download it:
+    //   curl -LO "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
+    // Reference: https://github.com/onnx/models/tree/main/vision/classification/squeezenet#model
+    let model_path = std::ffi::OsString::from("squeezenet1.0-8.onnx");
+
+    #[cfg(target_family = "windows")]
+    let model_path: Vec<u16> = model_path
+        .encode_wide()
+        .chain(std::iter::once(0)) // Make sure we have a null terminated string
+        .collect();
+    #[cfg(not(target_family = "windows"))]
+    let model_path: Vec<std::os::raw::c_char> = model_path
+        .as_bytes()
+        .iter()
+        .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
+        .map(|b| *b as std::os::raw::c_char)
+        .collect();
+
+    let mut session_ptr: *mut OrtSession = std::ptr::null_mut();
+
+    println!("Using Onnxruntime C API");
+    let status = unsafe {
+        g_ort.as_ref().unwrap().CreateSession.unwrap()(
+            env_ptr,
+            model_path.as_ptr(),
+            session_options_ptr,
+            &mut session_ptr,
+        )
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(session_ptr, std::ptr::null_mut());
+
+    //*************************************************************************
+    // print model input layer (node names, types, shape etc.)
+    // size_t num_input_nodes;
+    let mut allocator_ptr: *mut OrtAllocator = std::ptr::null_mut();
+    let status = unsafe {
+        g_ort
+            .as_ref()
+            .unwrap()
+            .GetAllocatorWithDefaultOptions
+            .unwrap()(&mut allocator_ptr)
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(allocator_ptr, std::ptr::null_mut());
+
+    // print number of model input nodes
+    let mut num_input_nodes: usize = 0;
+    let status = unsafe {
+        g_ort.as_ref().unwrap().SessionGetInputCount.unwrap()(session_ptr, &mut num_input_nodes)
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(num_input_nodes, 0);
+    println!("Number of inputs = {:?}", num_input_nodes);
+    let mut input_node_names: Vec<&str> = Vec::new();
+    let mut input_node_dims: Vec<i64> = Vec::new(); // simplify... this model has only 1 input node {1, 3, 224, 224}.
+                                                    // Otherwise need vector<vector<>>
+
+    // iterate over all input nodes
+    for i in 0..num_input_nodes {
+        // print input node names
+        let mut input_name: *mut i8 = std::ptr::null_mut();
+        let status = unsafe {
+            g_ort.as_ref().unwrap().SessionGetInputName.unwrap()(
+                session_ptr,
+                i,
+                allocator_ptr,
+                &mut input_name,
+            )
+        };
+        CheckStatus(g_ort, status).unwrap();
+        assert_ne!(input_name, std::ptr::null_mut());
+
+        // WARNING: The C function SessionGetInputName allocates memory for the string.
+        // We cannot let Rust free that string; the C side must free the string.
+        // We thus convert the pointer to a string slice (&str).
+        let input_name = char_p_to_str(input_name).unwrap();
+        println!("Input {} : name={}", i, input_name);
+        input_node_names.push(input_name);
+
+        // print input node types
+        let mut typeinfo_ptr: *mut OrtTypeInfo = std::ptr::null_mut();
+        let status = unsafe {
+            g_ort.as_ref().unwrap().SessionGetInputTypeInfo.unwrap()(
+                session_ptr,
+                i,
+                &mut typeinfo_ptr,
+            )
+        };
+        CheckStatus(g_ort, status).unwrap();
+        assert_ne!(typeinfo_ptr, std::ptr::null_mut());
+
+        let mut tensor_info_ptr: *const OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
+        let status = unsafe {
+            g_ort.as_ref().unwrap().CastTypeInfoToTensorInfo.unwrap()(
+                typeinfo_ptr,
+                &mut tensor_info_ptr,
+            )
+        };
+        CheckStatus(g_ort, status).unwrap();
+        assert_ne!(tensor_info_ptr, std::ptr::null_mut());
+
+        let mut type_: ONNXTensorElementDataType =
+            ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+        let status = unsafe {
+            g_ort.as_ref().unwrap().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_)
+        };
+        CheckStatus(g_ort, status).unwrap();
+        assert_ne!(
+            type_,
+            ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
+        );
+
+        println!("Input {} : type={}", i, type_ as i32);
+
+        // print input shapes/dims
+        let mut num_dims = 0;
+        let status = unsafe {
+            g_ort.as_ref().unwrap().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)
+        };
+        CheckStatus(g_ort, status).unwrap();
+        assert_ne!(num_dims, 0);
+
+        println!("Input {} : num_dims={}", i, num_dims);
+        input_node_dims.resize_with(num_dims as usize, Default::default);
+        let status = unsafe {
+            g_ort.as_ref().unwrap().GetDimensions.unwrap()(
+                tensor_info_ptr,
+                input_node_dims.as_mut_ptr(),
+                num_dims,
+            )
+        };
+        CheckStatus(g_ort, status).unwrap();
+
+        for j in 0..num_dims {
+            println!("Input {} : dim {}={}", i, j, input_node_dims[j as usize]);
+        }
+
+        unsafe { g_ort.as_ref().unwrap().ReleaseTypeInfo.unwrap()(typeinfo_ptr) };
+    }
+
+    // Results should be...
+    // Number of inputs = 1
+    // Input 0 : name = data_0
+    // Input 0 : type = 1
+    // Input 0 : num_dims = 4
+    // Input 0 : dim 0 = 1
+    // Input 0 : dim 1 = 3
+    // Input 0 : dim 2 = 224
+    // Input 0 : dim 3 = 224
+
+    //*************************************************************************
+    // Similar operations to get output node information.
+    // Use OrtSessionGetOutputCount(), OrtSessionGetOutputName()
+    // OrtSessionGetOutputTypeInfo() as shown above.
+
+    //*************************************************************************
+    // Score the model using sample data, and inspect values
+
+    let input_tensor_size = 224 * 224 * 3; // simplify ... using known dim values to calculate size
+                                           // use OrtGetTensorShapeElementCount() to get official size!
+
+    let output_node_names = &["softmaxout_1"];
+
+    // initialize input data with values in [0.0, 1.0]
+    let mut input_tensor_values: Vec<f32> = (0..input_tensor_size)
+        .map(|i| (i as f32) / ((input_tensor_size + 1) as f32))
+        .collect();
+
+    // create input tensor object from data values
+    let mut memory_info_ptr: *mut OrtMemoryInfo = std::ptr::null_mut();
+    let status = unsafe {
+        g_ort.as_ref().unwrap().CreateCpuMemoryInfo.unwrap()(
+            OrtAllocatorType::OrtArenaAllocator,
+            OrtMemType::OrtMemTypeDefault,
+            &mut memory_info_ptr,
+        )
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(memory_info_ptr, std::ptr::null_mut());
+
+    // FIXME: Check me!
+    let mut input_tensor_ptr: *mut OrtValue = std::ptr::null_mut();
+    let input_tensor_ptr_ptr: *mut *mut OrtValue = &mut input_tensor_ptr;
+    let input_tensor_values_ptr: *mut std::ffi::c_void =
+        input_tensor_values.as_mut_ptr().cast::<std::ffi::c_void>();
+    assert_ne!(input_tensor_values_ptr, std::ptr::null_mut());
+
+    let shape: *const i64 = input_node_dims.as_ptr();
+    assert_ne!(shape, std::ptr::null_mut());
+
+    let status = unsafe {
+        g_ort
+            .as_ref()
+            .unwrap()
+            .CreateTensorWithDataAsOrtValue
+            .unwrap()(
+            memory_info_ptr,
+            input_tensor_values_ptr,
+            input_tensor_size * std::mem::size_of::<f32>(),
+            shape,
+            4,
+            ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+            input_tensor_ptr_ptr,
+        )
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(input_tensor_ptr, std::ptr::null_mut());
+
+    let mut is_tensor = 0;
+    let status =
+        unsafe { g_ort.as_ref().unwrap().IsTensor.unwrap()(input_tensor_ptr, &mut is_tensor) };
+    CheckStatus(g_ort, status).unwrap();
+    assert_eq!(is_tensor, 1);
+
+    let input_tensor_ptr2: *const OrtValue = input_tensor_ptr as *const OrtValue;
+    let input_tensor_ptr3: *const *const OrtValue = &input_tensor_ptr2;
+
+    unsafe { g_ort.as_ref().unwrap().ReleaseMemoryInfo.unwrap()(memory_info_ptr) };
+
+    // score model & input tensor, get back output tensor
+
+    let input_node_names_cstring: Vec<std::ffi::CString> = input_node_names
+        .into_iter()
+        .map(|n| std::ffi::CString::new(n).unwrap())
+        .collect();
+    let input_node_names_ptr: Vec<*const i8> = input_node_names_cstring
+        .into_iter()
+        .map(|n| n.into_raw() as *const i8)
+        .collect();
+    let input_node_names_ptr_ptr: *const *const i8 = input_node_names_ptr.as_ptr();
+
+    let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
+        .iter()
+        .map(|n| std::ffi::CString::new(n.clone()).unwrap())
+        .collect();
+    let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
+        .iter()
+        .map(|n| n.as_ptr().cast::<i8>())
+        .collect();
+    let output_node_names_ptr_ptr: *const *const i8 = output_node_names_ptr.as_ptr();
+
+    let _input_node_names_cstring =
+        unsafe { std::ffi::CString::from_raw(input_node_names_ptr[0] as *mut i8) };
+    let run_options_ptr: *const OrtRunOptions = std::ptr::null();
+    let mut output_tensor_ptr: *mut OrtValue = std::ptr::null_mut();
+    let output_tensor_ptr_ptr: *mut *mut OrtValue = &mut output_tensor_ptr;
+
+    let status = unsafe {
+        g_ort.as_ref().unwrap().Run.unwrap()(
+            session_ptr,
+            run_options_ptr,
+            input_node_names_ptr_ptr,
+            input_tensor_ptr3,
+            1,
+            output_node_names_ptr_ptr,
+            1,
+            output_tensor_ptr_ptr,
+        )
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(output_tensor_ptr, std::ptr::null_mut());
+
+    let mut is_tensor = 0;
+    let status =
+        unsafe { g_ort.as_ref().unwrap().IsTensor.unwrap()(output_tensor_ptr, &mut is_tensor) };
+    CheckStatus(g_ort, status).unwrap();
+    assert_eq!(is_tensor, 1);
+
+    // Get pointer to output tensor float values
+    let mut floatarr: *mut f32 = std::ptr::null_mut();
+    let floatarr_ptr: *mut *mut f32 = &mut floatarr;
+    let floatarr_ptr_void: *mut *mut std::ffi::c_void =
+        floatarr_ptr.cast::<*mut std::ffi::c_void>();
+    let status = unsafe {
+        g_ort.as_ref().unwrap().GetTensorMutableData.unwrap()(output_tensor_ptr, floatarr_ptr_void)
+    };
+    CheckStatus(g_ort, status).unwrap();
+    assert_ne!(floatarr, std::ptr::null_mut());
+
+    assert!((unsafe { *floatarr.offset(0) } - 0.000_045).abs() < 1e-6);
+
+    // score the model, and print scores for first 5 classes
+    // NOTE: The C ONNX Runtime allocated the array; we shouldn't drop the vec
+    // but let C de-allocate instead.
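+    // `Vec::from_raw_parts(floatarr, 5, 5)` views the first five values of the
+    // C-owned buffer as a Vec without copying; the `std::mem::forget` below
+    // relinquishes that ownership again so Rust never frees memory it does not own.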
+    let floatarr_vec: Vec<f32> = unsafe { Vec::from_raw_parts(floatarr, 5, 5) };
+    for i in 0..5 {
+        println!("Score for class [{}] = {}", i, floatarr_vec[i]);
+    }
+    std::mem::forget(floatarr_vec);
+
+    // Results should be as below...
+    // Score for class[0] = 0.000045
+    // Score for class[1] = 0.003846
+    // Score for class[2] = 0.000125
+    // Score for class[3] = 0.001180
+    // Score for class[4] = 0.001317
+
+    unsafe { g_ort.as_ref().unwrap().ReleaseValue.unwrap()(output_tensor_ptr) };
+    unsafe { g_ort.as_ref().unwrap().ReleaseValue.unwrap()(input_tensor_ptr) };
+    unsafe { g_ort.as_ref().unwrap().ReleaseSession.unwrap()(session_ptr) };
+    unsafe { g_ort.as_ref().unwrap().ReleaseSessionOptions.unwrap()(session_options_ptr) };
+    unsafe { g_ort.as_ref().unwrap().ReleaseEnv.unwrap()(env_ptr) };
+
+    println!("Done!");
+}
+
+fn CheckStatus(g_ort: *const OrtApi, status: *const OrtStatus) -> Result<(), String> {
+    if status != std::ptr::null() {
+        let raw = unsafe { g_ort.as_ref().unwrap().GetErrorMessage.unwrap()(status) };
+        Err(char_p_to_str(raw).unwrap().to_string())
+    } else {
+        Ok(())
+    }
+}
+
+fn char_p_to_str<'a>(raw: *const i8) -> Result<&'a str, std::str::Utf8Error> {
+    let c_str = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8) };
+    c_str.to_str()
+}
diff --git a/rust/onnxruntime-sys/src/lib.rs b/rust/onnxruntime-sys/src/lib.rs
new file mode 100644
index 0000000000000..c1ba5c347a036
--- /dev/null
+++ b/rust/onnxruntime-sys/src/lib.rs
@@ -0,0 +1,15 @@
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+// Disable clippy and `u128` not being FFI-safe (see #1)
+#![allow(clippy::all)]
+#![allow(improper_ctypes)]
+
+include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
+
+#[cfg(target_os = "windows")]
+pub type OnnxEnumInt = i32;
+#[cfg(not(target_os = "windows"))]
+pub type OnnxEnumInt = u32;
+
+pub use libloading::library_filename;
diff --git a/rust/onnxruntime/Cargo.toml b/rust/onnxruntime/Cargo.toml
new file mode 100644
index 0000000000000..d52904c5e50a0
--- /dev/null
+++ b/rust/onnxruntime/Cargo.toml
@@ -0,0 +1,43 @@
+[package]
+authors = ["Nicolas Bigaouette <nbigaouette@gmail.com>"]
+edition = "2018"
+name = "onnxruntime"
+version = "0.0.14"
+
+description = "Wrapper around Microsoft's ONNX Runtime"
+documentation = "https://docs.rs/onnxruntime"
+homepage = "https://onnxruntime.ai/"
+license = "MIT OR Apache-2.0"
+readme = "../README.md"
+repository = "https://github.com/microsoft/onnxruntime"
+
+categories = ["science"]
+keywords = ["neuralnetworks", "onnx", "bindings"]
+
+[[test]]
+name = "integration_tests"
+required-features = ["model-fetching"]
+
+[dependencies]
+libloading = "0.7"
+ndarray = "0.15"
+once_cell = "1.17"
+onnxruntime-sys = { version = "0.0.14", path = "../onnxruntime-sys" }
+thiserror = "1.0"
+tracing = "0.1"
+
+# Enabled with 'model-fetching' feature
+ureq = { version = "2.1", optional = true }
+
+[dev-dependencies]
+image = "0.24"
+test-log = { version = "0.2", default-features = false, features = ["trace"] }
+tracing-subscriber = "0.2"
+ureq = "2.1"
+
+[features]
+# Fetch model from ONNX Model Zoo (https://github.com/onnx/models)
+model-fetching = ["ureq"]
+
+[package.metadata.docs.rs]
+features = ["model-fetching"]
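The `model-fetching` feature gated above enables downloading pre-trained models from the ONNX Model Zoo. A sketch of how a session can use it, assuming the `with_model_downloaded()` builder method and the `ImageClassification::SqueezeNet` variant referenced elsewhere in this patch (the import path is inferred from the `download` module layout shown further below):

```rust
use onnxruntime::{
    download::vision::ImageClassification, environment::Environment, LoggingLevel,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let environment = Environment::builder()
        .with_name("zoo")
        .with_log_level(LoggingLevel::Warning)
        .build()?;

    // Fetches SqueezeNet from the ONNX Model Zoo, then builds the session
    // from the downloaded file.
    let _session = environment
        .new_session_builder()?
        .with_model_downloaded(ImageClassification::SqueezeNet)?;

    Ok(())
}
```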
diff --git a/rust/onnxruntime/examples/issue22.rs b/rust/onnxruntime/examples/issue22.rs
new file mode 100644
index 0000000000000..6c96e899fa774
--- /dev/null
+++ b/rust/onnxruntime/examples/issue22.rs
@@ -0,0 +1,55 @@
+//! Example reproducing issue #22.
+//!
+//! `model.onnx` available to download here:
+//! https://drive.google.com/file/d/1FmL-Wpm06V-8wgRqvV3Skey_X98Ue4D_/view?usp=sharing
+
+use ndarray::Array2;
+use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel};
+use std::env::var;
+use tracing::Level;
+use tracing_subscriber::FmtSubscriber;
+
+fn main() {
+    // a builder for `FmtSubscriber`.
+    let subscriber = FmtSubscriber::builder()
+        // all spans/events with a level higher than TRACE (e.g., debug, info, warn, etc.)
+        // will be written to stdout.
+        .with_max_level(Level::TRACE)
+        // completes the builder.
+        .finish();
+
+    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
+
+    let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
+
+    let builder = Environment::builder()
+        .with_name("env")
+        .with_log_level(LoggingLevel::Warning);
+
+    let builder = if let Some(path) = path.clone() {
+        builder.with_library_path(path)
+    } else {
+        builder
+    };
+
+    let env = builder.build().unwrap();
+    let session = env
+        .new_session_builder()
+        .unwrap()
+        .with_graph_optimization_level(GraphOptimizationLevel::Basic)
+        .unwrap()
+        .with_model_from_file("model.onnx")
+        .unwrap();
+
+    println!("{:#?}", session.inputs);
+    println!("{:#?}", session.outputs);
+
+    let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
+    let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();
+
+    let inputs = vec![input_ids.into(), attention_mask.into()];
+
+    let outputs = session.run(inputs).unwrap();
+
+    print!("outputs: {:#?}", outputs[0].float_array().unwrap());
+}
diff --git a/rust/onnxruntime/examples/print_structure.rs b/rust/onnxruntime/examples/print_structure.rs
new file mode 100644
index 0000000000000..ce38218189616
--- /dev/null
+++ b/rust/onnxruntime/examples/print_structure.rs
@@ -0,0 +1,47 @@
+//! Display the input and output structure of an ONNX model.
+use onnxruntime::{environment, LoggingLevel};
+use std::{env::var, error::Error};
+
+fn main() -> Result<(), Box<dyn Error>> {
+    let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
+
+    let builder = environment::Environment::builder()
+        .with_name("onnx_metadata")
+        .with_log_level(LoggingLevel::Verbose);
+
+    let builder = if let Some(path) = path.clone() {
+        builder.with_library_path(path)
+    } else {
+        builder
+    };
+
+    let environment = builder.build().unwrap();
+
+    // provide path to .onnx model on disk
+    let path = std::env::args()
+        .nth(1)
+        .expect("Must provide an .onnx file as the first arg");
+
+    let session = environment
+        .new_session_builder()?
+        .with_graph_optimization_level(onnxruntime::GraphOptimizationLevel::Basic)?
+        .with_model_from_file(path)?;
+
+    println!("Inputs:");
+    for (index, input) in session.inputs.iter().enumerate() {
+        println!(
+            " {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
+            index, input.name, input.input_type, input.dimensions
+        )
+    }
+
+    println!("Outputs:");
+    for (index, output) in session.outputs.iter().enumerate() {
+        println!(
+            " {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
+            index, output.name, output.output_type, output.dimensions
+        );
+    }
+
+    Ok(())
+}
diff --git a/rust/onnxruntime/examples/sample.rs b/rust/onnxruntime/examples/sample.rs
new file mode 100644
index 0000000000000..9af5cf733ccae
--- /dev/null
+++ b/rust/onnxruntime/examples/sample.rs
@@ -0,0 +1,83 @@
+#![forbid(unsafe_code)]
+
+use onnxruntime::{environment::Environment, ndarray::Array, GraphOptimizationLevel, LoggingLevel};
+use std::env::var;
+use tracing::Level;
+use tracing_subscriber::FmtSubscriber;
+
+type Error = Box<dyn std::error::Error>;
+
+fn main() {
+    if let Err(e) = run() {
+        eprintln!("Error: {}", e);
+        std::process::exit(1);
+    }
+}
+
+fn run() -> Result<(), Error> {
+    // Setup the example's log level.
+    // NOTE: ONNX Runtime's log level is controlled separately when building the environment.
+    let subscriber = FmtSubscriber::builder()
+        .with_max_level(Level::TRACE)
+        .finish();
+
+    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
+
+    let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
+
+    let builder = Environment::builder()
+        .with_name("test")
+        .with_log_level(LoggingLevel::Warning);
+
+    let builder = if let Some(path) = path.clone() {
+        builder.with_library_path(path)
+    } else {
+        builder
+    };
+
+    let environment = builder.build().unwrap();
+
+    let session = environment
+        .new_session_builder()?
+        .with_graph_optimization_level(GraphOptimizationLevel::Basic)?
+        .with_intra_op_num_threads(1)?
+        // NOTE: The example uses SqueezeNet 1.0 (ONNX version: 1.3, Opset version: 8),
+        // _not_ SqueezeNet 1.1 as downloaded by '.with_model_downloaded(ImageClassification::SqueezeNet)'
+        // Obtain it with:
+        //    curl -LO "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
+        .with_model_from_file("squeezenet1.0-8.onnx")?;
+
+    let input0_shape: Vec<usize> = session.inputs[0]
+        .dimensions()
+        .map(std::option::Option::unwrap)
+        .collect();
+    let output0_shape: Vec<usize> = session.outputs[0]
+        .dimensions()
+        .map(std::option::Option::unwrap)
+        .collect();
+
+    assert_eq!(input0_shape, [1, 3, 224, 224]);
+    assert_eq!(output0_shape, [1, 1000, 1, 1]);
+
+    // Initialize input data with values in [0.0, 1.0].
+    let n: u32 = session.inputs[0]
+        .dimensions
+        .iter()
+        .map(|d| d.unwrap())
+        .product();
+    let array = Array::linspace(0.0_f32, 1.0, n as usize)
+        .into_shape(input0_shape)
+        .unwrap();
+    let input_tensor_values = vec![array.into()];
+
+    let outputs = session.run(input_tensor_values)?;
+
+    let output = outputs[0].float_array().unwrap();
+
+    assert_eq!(output.shape(), output0_shape.as_slice());
+    for i in 0..5 {
+        println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
+    }
+
+    Ok(())
+}
diff --git a/rust/onnxruntime/src/download.rs b/rust/onnxruntime/src/download.rs
new file mode 100644
index 0000000000000..0b600f3786ada
--- /dev/null
+++ b/rust/onnxruntime/src/download.rs
@@ -0,0 +1,113 @@
+//! Module controlling models downloadable from the ONNX Model Zoo.
+//!
+//! Pre-trained models are available from the
+//! [ONNX Model Zoo](https://github.com/onnx/models).
+//!
+//! A pre-trained model can be downloaded automatically using the
+//! [`SessionBuilder`](../session/struct.SessionBuilder.html)'s
+//! [`with_model_downloaded()`](../session/struct.SessionBuilder.html#method.with_model_downloaded) method.
+//!
+//! See [`AvailableOnnxModel`](enum.AvailableOnnxModel.html) for the different models available
+//! to download.
+
+#[cfg(feature = "model-fetching")]
+use std::{
+    fs, io,
+    path::{Path, PathBuf},
+    time::Duration,
+};
+
+#[cfg(feature = "model-fetching")]
+use crate::error::{OrtDownloadError, Result};
+
+#[cfg(feature = "model-fetching")]
+use tracing::info;
+
+pub mod language;
+pub mod vision;
+
+/// Available pre-trained models to download from [ONNX Model Zoo](https://github.com/onnx/models).
+///
+/// According to [ONNX Model Zoo](https://github.com/onnx/models)'s GitHub page:
+///
+/// > The ONNX Model Zoo is a collection of pre-trained, state-of-the-art models in the ONNX format
+/// > contributed by community members like you.
+#[derive(Debug, Clone)]
+pub enum AvailableOnnxModel {
+    /// Computer vision model
+    Vision(vision::Vision),
+    /// Natural language model
+    Language(language::Language),
+}
+
+trait ModelUrl {
+    fn fetch_url(&self) -> &'static str;
+}
+
+impl ModelUrl for AvailableOnnxModel {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            AvailableOnnxModel::Vision(model) => model.fetch_url(),
+            AvailableOnnxModel::Language(model) => model.fetch_url(),
+        }
+    }
+}
+
+impl AvailableOnnxModel {
+    #[cfg(feature = "model-fetching")]
+    #[tracing::instrument]
+    pub(crate) fn download_to<P>(&self, download_dir: P) -> Result<PathBuf>
+    where
+        P: AsRef<Path> + std::fmt::Debug,
+    {
+        let url = self.fetch_url();
+
+        let model_filename = PathBuf::from(url.split('/').last().unwrap());
+        let model_filepath = download_dir.as_ref().join(model_filename);
+
+        if model_filepath.exists() {
+            info!(
+                model_filepath = format!("{}", model_filepath.display()).as_str(),
+                "File already exists, not re-downloading.",
+            );
+            Ok(model_filepath)
+        } else {
+            info!(
+                model_filepath = format!("{}", model_filepath.display()).as_str(),
+                url = format!("{:?}", url).as_str(),
+                "Downloading file, please wait....",
+            );
+
+            let resp = ureq::get(url)
+                .timeout(Duration::from_secs(180)) // 3 minutes
+                .call()
+                .map_err(Box::new)
+                .map_err(OrtDownloadError::UreqError)?;
+
+            assert!(resp.has("Content-Length"));
+            let len = resp
+                .header("Content-Length")
+                .and_then(|s| s.parse::<usize>().ok())
+                .unwrap();
+            info!(len, "Downloading {} bytes...", len);
+
+            let mut reader = resp.into_reader();
+
+            let f = fs::File::create(&model_filepath).unwrap();
+            let mut writer = io::BufWriter::new(f);
+
+            let bytes_io_count =
+                io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
+
+            if bytes_io_count == len as u64 {
+                Ok(model_filepath)
+            } else {
+                Err(OrtDownloadError::CopyError {
+                    expected: len as u64,
+                    io: bytes_io_count,
+                }
+                .into())
+            }
+        }
+    }
+}
diff --git a/rust/onnxruntime/src/download/language.rs b/rust/onnxruntime/src/download/language.rs
new file mode 100644
index 0000000000000..9bf068cf379ef
--- /dev/null
+++ b/rust/onnxruntime/src/download/language.rs
@@ -0,0 +1,25 @@
+//! Module defining natural language models available to download.
+//!
+//! See [https://github.com/onnx/models#machine_comprehension](https://github.com/onnx/models#machine_comprehension).
+
+use super::ModelUrl;
+
+pub mod machine_comprehension;
+
+// Re-exports
+pub use machine_comprehension::MachineComprehension;
+
+/// Natural language models
+#[derive(Debug, Clone)]
+pub enum Language {
+    /// Machine comprehension
+    MachineComprehension(MachineComprehension),
+}
+
+impl ModelUrl for Language {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            Language::MachineComprehension(variant) => variant.fetch_url(),
+        }
+    }
+}
diff --git a/rust/onnxruntime/src/download/language/machine_comprehension.rs b/rust/onnxruntime/src/download/language/machine_comprehension.rs
new file mode 100644
index 0000000000000..76143aacd8b35
--- /dev/null
+++ b/rust/onnxruntime/src/download/language/machine_comprehension.rs
@@ -0,0 +1,127 @@
+//! Module defining machine comprehension models available to download.
+//!
+//! See [https://github.com/onnx/models#machine_comprehension](https://github.com/onnx/models#machine_comprehension)
+
+// Acronyms are specific ONNX model names and contain upper cases
+#![allow(clippy::upper_case_acronyms)]
+
+use crate::download::{language::Language, AvailableOnnxModel, ModelUrl};
+
+/// Machine Comprehension
+///
+/// > This subset of natural language processing models answers questions about a given context paragraph.
+///
+/// Source: [https://github.com/onnx/models#machine_comprehension](https://github.com/onnx/models#machine_comprehension)
+#[derive(Debug, Clone)]
+pub enum MachineComprehension {
+    /// Answers a query about a given context paragraph.
+    ///
+    /// > This model is a neural network for answering a query about a given context paragraph.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/bidirectional_attention_flow](https://github.com/onnx/models/tree/main/text/machine_comprehension/bidirectional_attention_flow)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    BiDAF,
+    /// Answers questions based on the context of the given input paragraph.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/bert-squad](https://github.com/onnx/models/tree/main/text/machine_comprehension/bert-squad)
+    ///
+    /// Variant downloaded: ONNX Version 1.5 with Opset Version 10.
+    BERTSquad,
+    /// Large transformer-based model that predicts sentiment based on given input text.
+    ///
+    /// > Transformer-based language model for text generation.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/roberta](https://github.com/onnx/models/tree/main/text/machine_comprehension/roberta)
+    RoBERTa(RoBERTa),
+    /// Large transformer-based language model that, given a sequence of words within some text, predicts the next word.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/gpt-2](https://github.com/onnx/models/tree/main/text/machine_comprehension/gpt-2)
+    GPT2(GPT2),
+}
+
+/// Large transformer-based model that predicts sentiment based on given input text.
+///
+/// > Transformer-based language model for text generation.
+///
+/// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/roberta](https://github.com/onnx/models/tree/main/text/machine_comprehension/roberta)
+#[derive(Debug, Clone)]
+pub enum RoBERTa {
+    /// Variant whose input is a sequence of words as a string. Example: "Text to encode: Hello, World"
+    ///
+    /// Variant downloaded: ONNX Version 1.6 with Opset Version 11.
+    RoBERTaBase,
+    /// Variant whose input is a sequence of words as a string including sentiment. Example: "This film is so good"
+    ///
+    /// Variant downloaded: ONNX Version 1.6 with Opset Version 9.
+    RoBERTaSequenceClassification,
+}
+
+/// Large transformer-based language model that, given a sequence of words within some text, predicts the next word.
+///
+/// > Transformer-based language model for text generation.
+///
+/// Source: [https://github.com/onnx/models/tree/main/text/machine_comprehension/gpt-2](https://github.com/onnx/models/tree/main/text/machine_comprehension/gpt-2)
+///
+/// Variant downloaded: ONNX Version 1.6 with Opset Version 10.
+#[derive(Debug, Clone)]
+pub enum GPT2 {
+    /// Pure GPT2
+    GPT2,
+    /// GPT2 + script changes
+    ///
+    /// See [https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/dependencies/GPT2-export.py](https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/dependencies/GPT2-export.py)
+    /// for the script changes.
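+    ///
+    /// A minimal fetch sketch (assumes the `model-fetching` feature and the
+    /// `SessionBuilder::with_model_downloaded` method referenced in this
+    /// crate's docs):
+    ///
+    /// ```ignore
+    /// let session = environment
+    ///     .new_session_builder()?
+    ///     .with_model_downloaded(GPT2::GPT2LmHead)?;
+    /// ```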
+    GPT2LmHead,
+}
+
+impl ModelUrl for MachineComprehension {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            MachineComprehension::BiDAF => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx",
+            MachineComprehension::BERTSquad => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx",
+            MachineComprehension::RoBERTa(variant) => variant.fetch_url(),
+            MachineComprehension::GPT2(variant) => variant.fetch_url(),
+        }
+    }
+}
+
+impl ModelUrl for RoBERTa {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            RoBERTa::RoBERTaBase => "https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-base-11.onnx",
+            RoBERTa::RoBERTaSequenceClassification => "https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.onnx",
+        }
+    }
+}
+
+impl ModelUrl for GPT2 {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            GPT2::GPT2 => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx",
+            GPT2::GPT2LmHead => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx",
+        }
+    }
+}
+
+impl From<MachineComprehension> for AvailableOnnxModel {
+    fn from(model: MachineComprehension) -> Self {
+        AvailableOnnxModel::Language(Language::MachineComprehension(model))
+    }
+}
+
+impl From<RoBERTa> for AvailableOnnxModel {
+    fn from(model: RoBERTa) -> Self {
+        AvailableOnnxModel::Language(Language::MachineComprehension(
+            MachineComprehension::RoBERTa(model),
+        ))
+    }
+}
+
+impl From<GPT2> for AvailableOnnxModel {
+    fn from(model: GPT2) -> Self {
+        AvailableOnnxModel::Language(Language::MachineComprehension(MachineComprehension::GPT2(
+            model,
+        )))
+    }
+}
diff --git a/rust/onnxruntime/src/download/vision.rs b/rust/onnxruntime/src/download/vision.rs
new file mode 100644
index 0000000000000..bc4d385b46fed
--- /dev/null
+++ b/rust/onnxruntime/src/download/vision.rs
@@ -0,0 +1,45 @@
+//! Module defining computer vision models available to download.
+//!
+//!
See [https://github.com/onnx/models#image_classification](https://github.com/onnx/models#image_classification) + +use super::ModelUrl; + +pub mod body_face_gesture_analysis; +pub mod domain_based_image_classification; +pub mod image_classification; +pub mod image_manipulation; +pub mod object_detection_image_segmentation; + +// Re-exports +pub use body_face_gesture_analysis::BodyFaceGestureAnalysis; +pub use domain_based_image_classification::DomainBasedImageClassification; +pub use image_classification::ImageClassification; +pub use image_manipulation::ImageManipulation; +pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation; + +/// Computer vision model +#[derive(Debug, Clone)] +pub enum Vision { + /// Domain-based Image Classification + DomainBasedImageClassification(DomainBasedImageClassification), + /// Image classification model + ImageClassification(ImageClassification), + /// Object Detection & Image Segmentation + ObjectDetectionImageSegmentation(ObjectDetectionImageSegmentation), + /// Body, Face & Gesture Analysis + BodyFaceGestureAnalysis(BodyFaceGestureAnalysis), + /// Image Manipulation + ImageManipulation(ImageManipulation), +} + +impl ModelUrl for Vision { + fn fetch_url(&self) -> &'static str { + match self { + Vision::DomainBasedImageClassification(variant) => variant.fetch_url(), + Vision::ImageClassification(variant) => variant.fetch_url(), + Vision::ObjectDetectionImageSegmentation(variant) => variant.fetch_url(), + Vision::BodyFaceGestureAnalysis(variant) => variant.fetch_url(), + Vision::ImageManipulation(variant) => variant.fetch_url(), + } + } +} diff --git a/rust/onnxruntime/src/download/vision/body_face_gesture_analysis.rs b/rust/onnxruntime/src/download/vision/body_face_gesture_analysis.rs new file mode 100644 index 0000000000000..1916f85776076 --- /dev/null +++ b/rust/onnxruntime/src/download/vision/body_face_gesture_analysis.rs @@ -0,0 +1,43 @@ +//! Module defining body, face and gesture analysis models available to download. +//! +//! See [https://github.com/onnx/models#body_analysis](https://github.com/onnx/models#body_analysis) + +use crate::download::{vision::Vision, AvailableOnnxModel, ModelUrl}; + +/// Body, Face & Gesture Analysis +/// +/// > Face detection models identify and/or recognize human faces and emotions in given images. Body and Gesture +/// > Analysis models identify gender and age in given image. +/// +/// Source: [https://github.com/onnx/models#body_analysis](https://github.com/onnx/models#body_analysis) +#[derive(Debug, Clone)] +pub enum BodyFaceGestureAnalysis { + /// A CNN based model for face recognition which learns discriminative features of faces and produces + /// embeddings for input face images. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/body_analysis/arcface](https://github.com/onnx/models/tree/main/vision/body_analysis/arcface) + /// + /// Variant downloaded: ONNX Version 1.3 with Opset Version 8. + ArcFace, + /// Deep CNN for emotion recognition trained on images of faces. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/body_analysis/emotion_ferplus](https://github.com/onnx/models/tree/main/vision/body_analysis/emotion_ferplus) + /// + /// Variant downloaded: ONNX Version 1.3 with Opset Version 8. 
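+    ///
+    /// (The Zoo's model card lists a 1x1x64x64 grayscale face crop as input;
+    /// verify against the page above before wiring up preprocessing.)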
+    EmotionFerPlus,
+}
+
+impl ModelUrl for BodyFaceGestureAnalysis {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            BodyFaceGestureAnalysis::ArcFace => "https://github.com/onnx/models/raw/main/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx",
+            BodyFaceGestureAnalysis::EmotionFerPlus => "https://github.com/onnx/models/raw/main/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx",
+        }
+    }
+}
+
+impl From<BodyFaceGestureAnalysis> for AvailableOnnxModel {
+    fn from(model: BodyFaceGestureAnalysis) -> Self {
+        AvailableOnnxModel::Vision(Vision::BodyFaceGestureAnalysis(model))
+    }
+}
diff --git a/rust/onnxruntime/src/download/vision/domain_based_image_classification.rs b/rust/onnxruntime/src/download/vision/domain_based_image_classification.rs
new file mode 100644
index 0000000000000..78387bf175795
--- /dev/null
+++ b/rust/onnxruntime/src/download/vision/domain_based_image_classification.rs
@@ -0,0 +1,30 @@
+//! Module defining domain-based image classification models available to download.
+//!
+//! See [https://github.com/onnx/models#domain-based-image-classification-](https://github.com/onnx/models#domain-based-image-classification-)
+
+use crate::download::{vision::Vision, AvailableOnnxModel, ModelUrl};
+
+/// Image classification model
+#[derive(Debug, Clone)]
+pub enum DomainBasedImageClassification {
+    /// Handwritten digit prediction using a CNN
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/mnist](https://github.com/onnx/models/tree/main/vision/classification/mnist)
+    ///
+    /// Variant downloaded: ONNX Version 1.3 with Opset Version 8.
+    Mnist,
+}
+
+impl ModelUrl for DomainBasedImageClassification {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            DomainBasedImageClassification::Mnist => "https://github.com/onnx/models/raw/main/vision/classification/mnist/model/mnist-8.onnx",
+        }
+    }
+}
+
+impl From<DomainBasedImageClassification> for AvailableOnnxModel {
+    fn from(model: DomainBasedImageClassification) -> Self {
+        AvailableOnnxModel::Vision(Vision::DomainBasedImageClassification(model))
+    }
+}
diff --git a/rust/onnxruntime/src/download/vision/image_classification.rs b/rust/onnxruntime/src/download/vision/image_classification.rs
new file mode 100644
index 0000000000000..7806a75547a42
--- /dev/null
+++ b/rust/onnxruntime/src/download/vision/image_classification.rs
@@ -0,0 +1,350 @@
+//! Module defining image classification models available to download.
+//!
+//! See [https://github.com/onnx/models#image_classification](https://github.com/onnx/models#image_classification)
+
+// Acronyms are specific ONNX model names and contain upper cases
+#![allow(clippy::upper_case_acronyms)]
+
+use crate::download::{vision::Vision, AvailableOnnxModel, ModelUrl};
+
+/// Image classification model
+///
+/// > This collection of models takes images as input, then classifies the major objects in the images
+/// > into 1000 object categories such as keyboard, mouse, pencil, and many animals.
+///
+/// Source: [https://github.com/onnx/models#image-classification-](https://github.com/onnx/models#image-classification-)
+#[derive(Debug, Clone)]
+pub enum ImageClassification {
+    /// Image classification aimed for mobile targets.
+    ///
+    /// > MobileNet models perform image classification - they take images as input and classify the major
+    /// > object in the image into a set of pre-defined classes. They are trained on the ImageNet dataset which
+    /// > contains images from 1000 classes.
MobileNet models are also very efficient in terms of speed and + /// > size and hence are ideal for embedded and mobile applications. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/mobilenet](https://github.com/onnx/models/tree/main/vision/classification/mobilenet) + /// + /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7. + MobileNet, + /// Image classification, trained on ImageNet with 1000 classes. + /// + /// > ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when + /// > high accuracy of classification is required. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/resnet](https://github.com/onnx/models/tree/main/vision/classification/resnet) + ResNet(ResNet), + /// A small CNN with AlexNet level accuracy on ImageNet with 50x fewer parameters. + /// + /// > SqueezeNet is a small CNN which achieves AlexNet level accuracy on ImageNet with 50x fewer parameters. + /// > SqueezeNet requires less communication across servers during distributed training, less bandwidth to + /// > export a new model from the cloud to an autonomous car and more feasible to deploy on FPGAs and other + /// > hardware with limited memory. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/squeezenet](https://github.com/onnx/models/tree/main/vision/classification/squeezenet) + /// + /// Variant downloaded: SqueezeNet v1.1, ONNX Version 1.2.1 with Opset Version 7. + SqueezeNet, + /// Image classification, trained on ImageNet with 1000 classes. + /// + /// > VGG models provide very high accuracies but at the cost of increased model sizes. They are ideal for + /// > cases when high accuracy of classification is essential and there are limited constraints on model sizes. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/vgg](https://github.com/onnx/models/tree/main/vision/classification/vgg) + Vgg(Vgg), + /// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2012. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/alexnet](https://github.com/onnx/models/tree/main/vision/classification/alexnet) + /// + /// Variant downloaded: ONNX Version 1.4 with Opset Version 9. + AlexNet, + /// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2014. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/googlenet](https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/googlenet) + /// + /// Variant downloaded: ONNX Version 1.4 with Opset Version 9. + GoogleNet, + /// Variant of AlexNet, it's the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2012. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/classification/caffenet](https://github.com/onnx/models/tree/main/vision/classification/caffenet) + /// + /// Variant downloaded: ONNX Version 1.4 with Opset Version 9. + CaffeNet, + /// Convolutional neural network for detection. + /// + /// > This model was made by transplanting the R-CNN SVM classifiers into a fc-rcnn classification layer. 
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/rcnn_ilsvrc13](https://github.com/onnx/models/tree/main/vision/classification/rcnn_ilsvrc13)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    RcnnIlsvrc13,
+    /// Convolutional neural network for classification.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/densenet-121](https://github.com/onnx/models/tree/main/vision/classification/densenet-121)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    DenseNet121,
+    /// Google's Inception
+    Inception(InceptionVersion),
+    /// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing power.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/shufflenet](https://github.com/onnx/models/tree/main/vision/classification/shufflenet)
+    ShuffleNet(ShuffleNetVersion),
+    /// Deep convolutional networks for classification.
+    ///
+    /// > This model's 4th layer has 512 maps instead of 1024 maps mentioned in the paper.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/zfnet-512](https://github.com/onnx/models/tree/main/vision/classification/zfnet-512)
+    ZFNet512,
+    /// Image classification model that achieves state-of-the-art accuracy.
+    ///
+    /// > It is designed to run on mobile CPU, GPU, and EdgeTPU devices, allowing for applications on mobile and IoT, where computational resources are limited.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4](https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4)
+    ///
+    /// Variant downloaded: ONNX Version 1.7.0 with Opset Version 11.
+    EfficientNetLite4,
+}
+
+/// Google's Inception
+#[derive(Debug, Clone)]
+pub enum InceptionVersion {
+    /// Google's Inception v1
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/inception_v1](https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/inception_v1)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    V1,
+    /// Google's Inception v2
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/inception_v2](https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/inception_v2)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    V2,
+}
+
+/// ResNet
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/classification/resnet](https://github.com/onnx/models/tree/main/vision/classification/resnet)
+#[derive(Debug, Clone)]
+pub enum ResNet {
+    /// ResNet v1
+    V1(ResNetV1),
+    /// ResNet v2
+    V2(ResNetV2),
+}
+/// ResNet v1
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/classification/resnet](https://github.com/onnx/models/tree/main/vision/classification/resnet)
+#[derive(Debug, Clone)]
+pub enum ResNetV1 {
+    /// ResNet18
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet18,
+    /// ResNet34
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet34,
+    /// ResNet50
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet50,
+    /// ResNet101
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet101,
+    /// ResNet152
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet152,
+}
+/// ResNet v2
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/classification/resnet](https://github.com/onnx/models/tree/main/vision/classification/resnet)
+#[derive(Debug, Clone)]
+pub enum ResNetV2 {
+    /// ResNet18
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet18,
+    /// ResNet34
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet34,
+    /// ResNet50
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet50,
+    /// ResNet101
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet101,
+    /// ResNet152
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    ResNet152,
+}
+
+/// VGG
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/classification/vgg](https://github.com/onnx/models/tree/main/vision/classification/vgg)
+#[derive(Debug, Clone)]
+pub enum Vgg {
+    /// VGG with 16 convolutional layers
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    Vgg16,
+    /// VGG with 16 convolutional layers, with batch normalization applied after each convolutional layer.
+    ///
+    /// The batch normalization leads to better convergence and slightly better accuracies.
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    Vgg16Bn,
+    /// VGG with 19 convolutional layers
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    Vgg19,
+    /// VGG with 19 convolutional layers, with batch normalization applied after each convolutional layer.
+    ///
+    /// The batch normalization leads to better convergence and slightly better accuracies.
+    ///
+    /// Variant downloaded: ONNX Version 1.2.1 with Opset Version 7.
+    Vgg19Bn,
+}
+
+/// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing power.
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/classification/shufflenet](https://github.com/onnx/models/tree/main/vision/classification/shufflenet)
+#[derive(Debug, Clone)]
+pub enum ShuffleNetVersion {
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/shufflenet](https://github.com/onnx/models/tree/main/vision/classification/shufflenet)
+    ///
+    /// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+    V1,
+    /// ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification.
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/classification/shufflenet](https://github.com/onnx/models/tree/main/vision/classification/shufflenet)
+    ///
+    /// Variant downloaded: ONNX Version 1.6 with Opset Version 10.
+ V2, +} + +impl ModelUrl for ImageClassification { + fn fetch_url(&self) -> &'static str { + match self { + ImageClassification::MobileNet => "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx", + ImageClassification::SqueezeNet => "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx", + ImageClassification::Inception(version) => version.fetch_url(), + ImageClassification::ResNet(version) => version.fetch_url(), + ImageClassification::Vgg(variant) => variant.fetch_url(), + ImageClassification::AlexNet => "https://github.com/onnx/models/raw/main/vision/classification/alexnet/model/bvlcalexnet-9.onnx", + ImageClassification::GoogleNet => "https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/googlenet/model/googlenet-9.onnx", + ImageClassification::CaffeNet => "https://github.com/onnx/models/raw/main/vision/classification/caffenet/model/caffenet-9.onnx", + ImageClassification::RcnnIlsvrc13 => "https://github.com/onnx/models/raw/main/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx", + ImageClassification::DenseNet121 => "https://github.com/onnx/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx", + ImageClassification::ShuffleNet(version) => version.fetch_url(), + ImageClassification::ZFNet512 => "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx", + ImageClassification::EfficientNetLite4 => "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4.onnx" + } + } +} + +impl ModelUrl for InceptionVersion { + fn fetch_url(&self) -> &'static str { + match self { + InceptionVersion::V1 => "https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-9.onnx", + InceptionVersion::V2 => "https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx", + } + } +} + +impl ModelUrl for ResNet { + fn fetch_url(&self) -> &'static str { + match self { + ResNet::V1(variant) => variant.fetch_url(), + ResNet::V2(variant) => variant.fetch_url(), + } + } +} + +impl ModelUrl for ResNetV1 { + fn fetch_url(&self) -> &'static str { + match self { + ResNetV1::ResNet18 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx", + ResNetV1::ResNet34 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet34-v1-7.onnx", + ResNetV1::ResNet50 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v1-7.onnx", + ResNetV1::ResNet101 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet101-v1-7.onnx", + ResNetV1::ResNet152 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet152-v1-7.onnx", + } + } +} + +impl ModelUrl for ResNetV2 { + fn fetch_url(&self) -> &'static str { + match self { + ResNetV2::ResNet18 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx", + ResNetV2::ResNet34 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet34-v2-7.onnx", + ResNetV2::ResNet50 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx", + ResNetV2::ResNet101 => "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet101-v2-7.onnx", + ResNetV2::ResNet152 => 
"https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet152-v2-7.onnx", + } + } +} + +impl ModelUrl for Vgg { + fn fetch_url(&self) -> &'static str { + match self { + Vgg::Vgg16 => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg16-7.onnx", + Vgg::Vgg16Bn => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg16-bn-7.onnx", + Vgg::Vgg19 => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg19-7.onnx", + Vgg::Vgg19Bn => "https://github.com/onnx/models/raw/main/vision/classification/vgg/model/vgg19-bn-7.onnx", + } + } +} + +impl ModelUrl for ShuffleNetVersion { + fn fetch_url(&self) -> &'static str { + match self { + ShuffleNetVersion::V1 => "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx", + ShuffleNetVersion::V2 => "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx", + } + } +} + +impl From for AvailableOnnxModel { + fn from(model: ImageClassification) -> Self { + AvailableOnnxModel::Vision(Vision::ImageClassification(model)) + } +} + +impl From for AvailableOnnxModel { + fn from(variant: ResNet) -> Self { + AvailableOnnxModel::Vision(Vision::ImageClassification(ImageClassification::ResNet( + variant, + ))) + } +} + +impl From for AvailableOnnxModel { + fn from(variant: Vgg) -> Self { + AvailableOnnxModel::Vision(Vision::ImageClassification(ImageClassification::Vgg( + variant, + ))) + } +} + +impl From for AvailableOnnxModel { + fn from(variant: InceptionVersion) -> Self { + AvailableOnnxModel::Vision(Vision::ImageClassification(ImageClassification::Inception( + variant, + ))) + } +} + +impl From for AvailableOnnxModel { + fn from(variant: ShuffleNetVersion) -> Self { + AvailableOnnxModel::Vision(Vision::ImageClassification( + ImageClassification::ShuffleNet(variant), + )) + } +} diff --git a/rust/onnxruntime/src/download/vision/image_manipulation.rs b/rust/onnxruntime/src/download/vision/image_manipulation.rs new file mode 100644 index 0000000000000..4a67e429133d1 --- /dev/null +++ b/rust/onnxruntime/src/download/vision/image_manipulation.rs @@ -0,0 +1,86 @@ +//! Module defining image manipulation models available to download. +//! +//! See [https://github.com/onnx/models#image_manipulation](https://github.com/onnx/models#image_manipulation) + +use crate::download::{vision::Vision, AvailableOnnxModel, ModelUrl}; + +/// Image Manipulation +/// +/// > Image manipulation models use neural networks to transform input images to modified output images. Some +/// > popular models in this category involve style transfer or enhancing images by increasing resolution. +/// +/// Source: [https://github.com/onnx/models#image_manipulation](https://github.com/onnx/models#image_manipulation) +#[derive(Debug, Clone)] +pub enum ImageManipulation { + /// Super Resolution + /// + /// > The Super Resolution machine learning model sharpens and upscales the input image to refine the + /// > details and improve quality. + /// + /// Source: [https://github.com/onnx/models/tree/main/vision/super_resolution/sub_pixel_cnn_2016](https://github.com/onnx/models/tree/main/vision/super_resolution/sub_pixel_cnn_2016) + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 10. + SuperResolution, + /// Fast Neural Style Transfer + /// + /// > This artistic style transfer model mixes the content of an image with the style of another image. 
+    /// > Examples of the styles can be seen
+    /// > [in this PyTorch example](https://github.com/pytorch/examples/tree/main/fast_neural_style#models).
+    ///
+    /// Source: [https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style](https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style)
+    FastNeuralStyleTransfer(FastNeuralStyleTransferStyle),
+}
+
+/// Fast Neural Style Transfer Style
+///
+/// Source: [https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style](https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style)
+///
+/// Variant downloaded: ONNX Version 1.4 with Opset Version 9.
+#[derive(Debug, Clone)]
+pub enum FastNeuralStyleTransferStyle {
+    /// Mosaic style
+    Mosaic,
+    /// Candy style
+    Candy,
+    /// RainPrincess style
+    RainPrincess,
+    /// Udnie style
+    Udnie,
+    /// Pointilism style
+    Pointilism,
+}
+
+impl ModelUrl for ImageManipulation {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            ImageManipulation::SuperResolution => "https://github.com/onnx/models/raw/main/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx",
+            ImageManipulation::FastNeuralStyleTransfer(style) => style.fetch_url(),
+        }
+    }
+}
+
+impl ModelUrl for FastNeuralStyleTransferStyle {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            FastNeuralStyleTransferStyle::Mosaic => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/mosaic-9.onnx",
+            FastNeuralStyleTransferStyle::Candy => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/candy-9.onnx",
+            FastNeuralStyleTransferStyle::RainPrincess => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx",
+            FastNeuralStyleTransferStyle::Udnie => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/udnie-9.onnx",
+            FastNeuralStyleTransferStyle::Pointilism => "https://github.com/onnx/models/raw/main/vision/style_transfer/fast_neural_style/model/pointilism-9.onnx",
+        }
+    }
+}
+
+impl From<ImageManipulation> for AvailableOnnxModel {
+    fn from(model: ImageManipulation) -> Self {
+        AvailableOnnxModel::Vision(Vision::ImageManipulation(model))
+    }
+}
+
+impl From<FastNeuralStyleTransferStyle> for AvailableOnnxModel {
+    fn from(style: FastNeuralStyleTransferStyle) -> Self {
+        AvailableOnnxModel::Vision(Vision::ImageManipulation(
+            ImageManipulation::FastNeuralStyleTransfer(style),
+        ))
+    }
+}
diff --git a/rust/onnxruntime/src/download/vision/object_detection_image_segmentation.rs b/rust/onnxruntime/src/download/vision/object_detection_image_segmentation.rs
new file mode 100644
index 0000000000000..ff95154c20c21
--- /dev/null
+++ b/rust/onnxruntime/src/download/vision/object_detection_image_segmentation.rs
@@ -0,0 +1,107 @@
+//! Module defining object detection and image segmentation models available to download.
+//!
+//! See [https://github.com/onnx/models#object_detection](https://github.com/onnx/models#object_detection)
+
+// Acronyms are specific ONNX model names and contain upper cases
+#![allow(clippy::upper_case_acronyms)]
+
+use crate::download::{vision::Vision, AvailableOnnxModel, ModelUrl};
+
+/// Object Detection & Image Segmentation
+///
+/// > Object detection models detect the presence of multiple objects in an image and segment out areas of the
+/// > image where the objects are detected. Semantic segmentation models partition an input image by labeling each pixel
+/// > into a set of pre-defined categories.
+/// +/// Source: [https://github.com/onnx/models#object_detection](https://github.com/onnx/models#object_detection) +#[derive(Debug, Clone)] +pub enum ObjectDetectionImageSegmentation { + /// A real-time CNN for object detection that detects 20 different classes. A smaller version of the + /// more complex full YOLOv2 network. + /// + /// Variant downloaded: ONNX Version 1.3 with Opset Version 8. + TinyYoloV2, + /// Single Stage Detector: real-time CNN for object detection that detects 80 different classes. + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 10. + Ssd, + /// A variant of MobileNet that uses the Single Shot Detector (SSD) model framework. The model detects 80 + /// different object classes and locates up to 10 objects in an image. + /// + /// Variant downloaded: ONNX Version 1.7.0 with Opset Version 10. + SSDMobileNetV1, + /// Increases efficiency from R-CNN by connecting a RPN with a CNN to create a single, unified network for + /// object detection that detects 80 different classes. + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 10. + FasterRcnn, + /// A real-time neural network for object instance segmentation that detects 80 different classes. Extends + /// Faster R-CNN as each of the 300 elected ROIs go through 3 parallel branches of the network: label + /// prediction, bounding box prediction and mask prediction. + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 10. + MaskRcnn, + /// A real-time dense detector network for object detection that addresses class imbalance through Focal Loss. + /// RetinaNet is able to match the speed of previous one-stage detectors and defines the state-of-the-art in + /// two-stage detectors (surpassing R-CNN). + /// + /// Variant downloaded: ONNX Version 1.6.0 with Opset Version 9. + RetinaNet, + /// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses a + /// single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than + /// Faster R-CNN. + /// + /// Variant downloaded: ONNX Version 1.3 with Opset Version 8. + YoloV2, + /// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses + /// a single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than + /// Faster R-CNN. This model is trained with COCO dataset and contains 80 classes. + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 9. + YoloV2Coco, + /// A deep CNN model for real-time object detection that detects 80 different classes. A little bigger than + /// YOLOv2 but still very fast. As accurate as SSD but 3 times faster. + /// + /// Variant downloaded: ONNX Version 1.5 with Opset Version 10. + YoloV3, + /// A smaller version of YOLOv3 model. + /// + /// Variant downloaded: ONNX Version 1.6 with Opset Version 11. + TinyYoloV3, + /// Optimizes the speed and accuracy of object detection. Two times faster than EfficientDet. It improves + /// YOLOv3's AP and FPS by 10% and 12%, respectively, with mAP50 of 52.32 on the COCO 2017 dataset and + /// FPS of 41.7 on Tesla 100. + /// + /// Variant downloaded: ONNX Version 1.6 with Opset Version 11. + YoloV4, + /// Deep CNN based pixel-wise semantic segmentation model with >80% mIOU (mean Intersection Over Union). + /// Trained on cityscapes dataset, which can be effectively implemented in self driving vehicle systems. + /// + /// Variant downloaded: ONNX Version 1.2.2 with Opset Version 7. 
+    Duc,
+}
+
+impl ModelUrl for ObjectDetectionImageSegmentation {
+    fn fetch_url(&self) -> &'static str {
+        match self {
+            ObjectDetectionImageSegmentation::TinyYoloV2 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-8.onnx",
+            ObjectDetectionImageSegmentation::Ssd => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/ssd/model/ssd-10.onnx",
+            ObjectDetectionImageSegmentation::SSDMobileNetV1 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx",
+            ObjectDetectionImageSegmentation::FasterRcnn => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/faster-rcnn/model/FasterRCNN-10.onnx",
+            ObjectDetectionImageSegmentation::MaskRcnn => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.onnx",
+            ObjectDetectionImageSegmentation::RetinaNet => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/retinanet/model/retinanet-9.onnx",
+            ObjectDetectionImageSegmentation::YoloV2 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov2/model/yolov2-voc-8.onnx",
+            ObjectDetectionImageSegmentation::YoloV2Coco => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov2-coco/model/yolov2-coco-9.onnx",
+            ObjectDetectionImageSegmentation::YoloV3 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov3/model/yolov3-10.onnx",
+            ObjectDetectionImageSegmentation::TinyYoloV3 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/tiny-yolov3/model/tiny-yolov3-11.onnx",
+            ObjectDetectionImageSegmentation::YoloV4 => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/model/yolov4.onnx",
+            ObjectDetectionImageSegmentation::Duc => "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/duc/model/ResNet101-DUC-7.onnx",
+        }
+    }
+}
+
+impl From<ObjectDetectionImageSegmentation> for AvailableOnnxModel {
+    fn from(model: ObjectDetectionImageSegmentation) -> Self {
+        AvailableOnnxModel::Vision(Vision::ObjectDetectionImageSegmentation(model))
+    }
+}
diff --git a/rust/onnxruntime/src/environment.rs b/rust/onnxruntime/src/environment.rs
new file mode 100644
index 0000000000000..04c34ab38c7b9
--- /dev/null
+++ b/rust/onnxruntime/src/environment.rs
@@ -0,0 +1,373 @@
+//! Module containing environment types
+
+use crate::{
+    error::{status_to_result, OrtError, Result},
+    onnxruntime::custom_logger,
+    session::SessionBuilder,
+    LoggingLevel,
+};
+use once_cell::sync::OnceCell;
+use onnxruntime_sys as sys;
+use onnxruntime_sys::library_filename;
+use std::{
+    ffi::CString,
+    path::PathBuf,
+    ptr::{null, null_mut},
+    sync::{Arc, Mutex, MutexGuard},
+};
+use sys::{onnxruntime, ORT_API_VERSION};
+use tracing::{debug, warn};
+
+pub(crate) static ENV: OnceCell<Arc<Mutex<_EnvironmentSingleton>>> = OnceCell::new();
+
+pub(crate) static LIB: OnceCell<onnxruntime> = OnceCell::new();
+
+#[derive(Debug)]
+pub(crate) struct _EnvironmentSingleton {
+    name: CString,
+    pub(crate) env_ptr: *mut sys::OrtEnv,
+
+    pub api: *const sys::OrtApi,
+}
+
+impl _EnvironmentSingleton {
+    pub(crate) unsafe fn api(&self) -> sys::OrtApi {
+        *self.api
+    }
+}
+
+unsafe impl Send for _EnvironmentSingleton {}
+
+unsafe impl Sync for _EnvironmentSingleton {}
+
+/// An [`Environment`](environment/struct.Environment.html) is the main entry point of the ONNX Runtime.
+///
+/// Only one ONNXRuntime environment can be created per process.
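+/// (Requesting a second environment simply hands back a handle to the one
+/// already created; the tests in this module assert that the underlying
+/// `OrtEnv` pointer is identical across builders.)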
+/// The `onnxruntime` crate
+/// uses a singleton (through `once_cell`'s `OnceCell`) to enforce this.
+///
+/// Once an environment is created, a [`Session`](../session/struct.Session.html)
+/// can be obtained from it.
+///
+/// **NOTE**: While the [`Environment`](environment/struct.Environment.html) constructor takes a `name` parameter
+/// to name the environment, only the first name will be considered if many environments
+/// are created.
+///
+/// # Example
+///
+/// ```no_run
+/// # use std::error::Error;
+/// # use std::env::var;
+/// # use onnxruntime::{environment::Environment, LoggingLevel};
+/// # fn main() -> Result<(), Box<dyn Error>> {
+/// # let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
+///
+/// let builder = Environment::builder()
+///     .with_name("test")
+///     .with_log_level(LoggingLevel::Warning);
+///
+/// let builder = if let Some(path) = path {
+///     builder.with_library_path(path)
+/// } else {
+///     builder
+/// };
+/// let environment = builder.build()?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct Environment {
+    pub(crate) env: _Environment,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct _Environment {
+    env: Arc<Mutex<_EnvironmentSingleton>>,
+}
+
+impl _Environment {
+    pub(crate) fn env(&self) -> MutexGuard<_EnvironmentSingleton> {
+        self.env.lock().expect("The lock is poisoned")
+    }
+}
+
+impl std::fmt::Debug for Environment {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.env.fmt(f)
+    }
+}
+
+impl Environment {
+    /// Create a new environment builder using default values
+    /// (name: `default`, log level: [`LoggingLevel::Warning`](../enum.LoggingLevel.html#variant.Warning))
+    #[must_use]
+    pub fn builder() -> EnvBuilder {
+        EnvBuilder {
+            name: "default".into(),
+            log_level: LoggingLevel::Warning,
+            path: None,
+        }
+    }
+
+    /// Return the name of the current environment
+    #[must_use]
+    pub fn name(&self) -> String {
+        self.env().name.to_str().unwrap().to_string()
+    }
+
+    pub(crate) fn env(&self) -> MutexGuard<_EnvironmentSingleton> {
+        self.env.env()
+    }
+
+    #[tracing::instrument]
+    fn new(name: &str, log_level: LoggingLevel, path: Option<PathBuf>) -> Result<Environment> {
+        let lib = if let Some(path) = path {
+            LIB.get_or_try_init(|| unsafe { onnxruntime::new(path) })?
+        } else {
+            LIB.get_or_try_init(|| unsafe { onnxruntime::new(library_filename("onnxruntime")) })?
+        };
+        let env = ENV.get_or_try_init(|| {
+            debug!("Environment not yet initialized, creating a new one.");
+
+            let api = unsafe { (*lib.OrtGetApiBase()).GetApi.unwrap()(ORT_API_VERSION) };
+
+            let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
+
+            let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
+            // FIXME: What should go here?
+            let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
+
+            let cname = CString::new(name).unwrap();
+            unsafe {
+                let create_env_with_custom_logger = (*api).CreateEnvWithCustomLogger.unwrap();
+                let status = create_env_with_custom_logger(
+                    logging_function,
+                    logger_param,
+                    log_level.into(),
+                    cname.as_ptr(),
+                    &mut env_ptr,
+                );
+
+                status_to_result(status).map_err(OrtError::Environment)?;
+            }
+            debug!(
+                env_ptr = format!("{:?}", env_ptr).as_str(),
+                "Environment created."
+            );
+
+            Ok::<_, OrtError>(Arc::new(Mutex::new(_EnvironmentSingleton {
+                name: cname,
+                env_ptr,
+                api,
+            })))
+        })?;
+
+        let mut guard = env.lock().expect("Lock is poisoned");
+
+        if guard.env_ptr.is_null() || guard.api.is_null() {
+            debug!("Environment not yet initialized, creating a new one.");
+
+            let api = unsafe { (*lib.OrtGetApiBase()).GetApi.unwrap()(ORT_API_VERSION) };
+
+            let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
+
+            let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
+            // FIXME: What should go here?
+            let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
+
+            let cname = CString::new(name).unwrap();
+            unsafe {
+                let create_env_with_custom_logger = (*api).CreateEnvWithCustomLogger.unwrap();
+                let status = create_env_with_custom_logger(
+                    logging_function,
+                    logger_param,
+                    log_level.into(),
+                    cname.as_ptr(),
+                    &mut env_ptr,
+                );
+
+                status_to_result(status).map_err(OrtError::Environment)?;
+            }
+            debug!(
+                env_ptr = format!("{:?}", env_ptr).as_str(),
+                "Environment created."
+            );
+
+            guard.env_ptr = env_ptr;
+            guard.api = api;
+            guard.name = cname;
+        }
+
+        Ok(Environment {
+            env: _Environment { env: env.clone() },
+        })
+    }
+
+    /// Create a new [`SessionBuilder`](../session/struct.SessionBuilder.html)
+    /// used to create a new ONNXRuntime session.
+    pub fn new_session_builder(&self) -> Result<SessionBuilder> {
+        SessionBuilder::new(self)
+    }
+}
+
+impl Drop for Environment {
+    fn drop(&mut self) {
+        if Arc::strong_count(ENV.get().unwrap()) == 2 {
+            let env = &mut *ENV.get().unwrap().lock().expect("Lock is poisoned");
+
+            unsafe {
+                let release_env = env.api().ReleaseEnv.unwrap();
+                release_env(env.env_ptr);
+
+                env.api = null();
+
+                env.env_ptr = null_mut();
+                env.name = CString::default();
+            };
+        }
+    }
+}
+
+/// Struct used to build an [`Environment`](environment/struct.Environment.html)
+///
+/// This is the crate's main entry point. An environment _must_ be created
+/// as the first step. An [`Environment`](environment/struct.Environment.html) can only be built
+/// using `EnvBuilder` to configure it.
+///
+/// **NOTE**: If the same configuration method (for example [`with_name()`](struct.EnvBuilder.html#method.with_name))
+/// is called multiple times, the last value will have precedence.
+pub struct EnvBuilder {
+    name: String,
+    log_level: LoggingLevel,
+    path: Option<PathBuf>,
+}
+
+impl EnvBuilder {
+    /// Configure the environment with a given name
+    ///
+    /// **NOTE**: Since ONNXRuntime can only define one environment per process,
+    /// creating multiple environments using multiple `EnvBuilder`s will
+    /// end up re-using the same environment internally; a new one will _not_
+    /// be created. New parameters will be ignored.
+    pub fn with_name<S>(mut self, name: S) -> EnvBuilder
+    where
+        S: Into<String>,
+    {
+        self.name = name.into();
+        self
+    }
+
+    /// Add a library path to the ONNX Runtime shared library.
+    ///
+    /// **NOTE**: The library path can be an absolute path or relative (to the executable) path.
+    /// If no library path is specified, it is expected that the OS can find the ONNX Runtime shared
+    /// library in the normal manner for that OS.
+    pub fn with_library_path<P: Into<PathBuf>>(mut self, path: P) -> EnvBuilder {
+        self.path = Some(path.into());
+        self
+    }
+
+    /// Configure the environment with a given log level
+    ///
+    /// **NOTE**: Since ONNXRuntime can only define one environment per process,
+    /// creating multiple environments using multiple `EnvBuilder`s will
+    /// end up re-using the same environment internally; a new one will _not_
+    /// be created. New parameters will be ignored.
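+    ///
+    /// A sketch of the consequence (mirroring this module's tests):
+    ///
+    /// ```ignore
+    /// let first = Environment::builder().with_name("a").build()?;
+    /// // Within the same process the environment is reused: the name "b" and
+    /// // any other new parameters are ignored.
+    /// let second = Environment::builder().with_name("b").build()?;
+    /// ```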
+    #[must_use]
+    pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
+        self.log_level = log_level;
+        self
+    }
+
+    /// Commit the configuration to a new [`Environment`](environment/struct.Environment.html)
+    pub fn build(self) -> Result<Environment> {
+        Environment::new(&self.name, self.log_level, self.path)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use std::env::var;
+
+    use super::*;
+    use test_log::test;
+
+    pub(crate) static ONNX_RUNTIME_LIBRARY_PATH: &str = "RUST_ONNXRUNTIME_LIBRARY_PATH";
+
+    #[test]
+    fn sequential_environment_creation() {
+        let first_name: String = "sequential_environment_creation".into();
+
+        let path = var(ONNX_RUNTIME_LIBRARY_PATH).ok();
+
+        let builder = Environment::builder()
+            .with_name(first_name.clone())
+            .with_log_level(LoggingLevel::Warning);
+
+        let builder = if let Some(path) = path.clone() {
+            builder.with_library_path(path)
+        } else {
+            builder
+        };
+
+        let env = builder.build().unwrap();
+
+        let mut prev_env_ptr = env.env().env_ptr;
+
+        for i in 0..10 {
+            let name = format!("sequential_environment_creation: {}", i);
+            let builder = Environment::builder()
+                .with_name(name.clone())
+                .with_log_level(LoggingLevel::Warning);
+
+            let builder = if let Some(ref path) = path {
+                builder.with_library_path(path)
+            } else {
+                builder
+            };
+
+            let env = builder.build().unwrap();
+            let next_env_ptr = env.env().env_ptr;
+            assert_eq!(next_env_ptr, prev_env_ptr);
+            prev_env_ptr = next_env_ptr;
+        }
+    }
+
+    #[test]
+    fn concurrent_environment_creations() {
+        let initial_name = "concurrent_environment_creation";
+
+        let path = var(ONNX_RUNTIME_LIBRARY_PATH).ok();
+
+        let main_env = Environment::new(initial_name, LoggingLevel::Warning, path.clone()).unwrap();
+        let main_env_ptr = main_env.env().env_ptr as usize;
+
+        let children: Vec<_> = (0..10)
+            .map(|t| {
+                let path = path.clone();
+
+                std::thread::spawn(move || {
+                    let name = format!("concurrent_environment_creation: {}", t);
+                    let builder = Environment::builder()
+                        .with_name(name.clone())
+                        .with_log_level(LoggingLevel::Warning);
+
+                    let builder = if let Some(path) = path {
+                        builder.with_library_path(path)
+                    } else {
+                        builder
+                    };
+
+                    let env = builder.build().unwrap();
+
+                    assert_eq!(env.env().env_ptr as usize, main_env_ptr);
+                })
+            })
+            .collect();
+
+        assert_eq!(main_env.env().env_ptr as usize, main_env_ptr);
+
+        let res: Vec<std::thread::Result<()>> = children
+            .into_iter()
+            .map(std::thread::JoinHandle::join)
+            .collect();
+        assert!(res.into_iter().all(|r| std::result::Result::is_ok(&r)));
+    }
+}
diff --git a/rust/onnxruntime/src/error.rs b/rust/onnxruntime/src/error.rs
new file mode 100644
index 0000000000000..fc44e2b33930e
--- /dev/null
+++ b/rust/onnxruntime/src/error.rs
@@ -0,0 +1,249 @@
+//! Module containing error definitions.
+
+use std::{io, path::PathBuf};
+
+use thiserror::Error;
+
+use onnxruntime_sys as sys;
+
+use crate::{char_p_to_string, environment::ENV};
+
+/// Type alias for the `Result` type returned by this crate's APIs
+pub type Result<T> = std::result::Result<T, OrtError>;
+
+/// Error type centralizing all possible errors
+#[non_exhaustive]
+#[derive(Error, Debug)]
+pub enum OrtError {
+    /// For errors with libloading
+    #[error("Failed to load or call onnxruntime library {0}")]
+    Library(#[from] libloading::Error),
+    /// The C API can communicate with the caller through a C `char *`, which needs to be
+    /// converted to Rust's `String`. This operation can fail.
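+    /// (The underlying cause is carried in the wrapped [`OrtApiError`] value.)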
+    #[error("Failed to construct String")]
+    StringConversion(OrtApiError),
+    // FIXME: Move these to another enum (they are C API call errors)
+    /// An error occurred when creating an ONNXRuntime environment
+    #[error("Failed to create environment: {0}")]
+    Environment(OrtApiError),
+    /// Error occurred when creating ONNXRuntime session options
+    #[error("Failed to create session options: {0}")]
+    SessionOptions(OrtApiError),
+    /// Error occurred when creating an ONNXRuntime session
+    #[error("Failed to create session: {0}")]
+    Session(OrtApiError),
+    /// Error occurred when getting an ONNXRuntime allocator
+    #[error("Failed to get allocator: {0}")]
+    Allocator(OrtApiError),
+    /// Error occurred when counting ONNXRuntime inputs or outputs
+    #[error("Failed to get input or output count: {0}")]
+    InOutCount(OrtApiError),
+    /// Error occurred when getting ONNXRuntime input name
+    #[error("Failed to get input name: {0}")]
+    InputName(OrtApiError),
+    /// Error occurred when getting ONNXRuntime type information
+    #[error("Failed to get type info: {0}")]
+    GetTypeInfo(OrtApiError),
+    /// Error occurred when casting ONNXRuntime type information to tensor information
+    #[error("Failed to cast type info to tensor info: {0}")]
+    CastTypeInfoToTensorInfo(OrtApiError),
+    /// Error occurred when getting tensor elements type
+    #[error("Failed to get tensor element type: {0}")]
+    TensorElementType(OrtApiError),
+    /// Error occurred when getting ONNXRuntime dimensions count
+    #[error("Failed to get dimensions count: {0}")]
+    GetDimensionsCount(OrtApiError),
+    /// Error occurred when getting ONNXRuntime dimensions
+    #[error("Failed to get dimensions: {0}")]
+    GetDimensions(OrtApiError),
+    /// Error occurred when creating CPU memory information
+    #[error("Failed to create CPU memory info: {0}")]
+    CreateCpuMemoryInfo(OrtApiError),
+    /// Error occurred when creating ONNXRuntime tensor
+    #[error("Failed to create tensor: {0}")]
+    CreateTensor(OrtApiError),
+    /// Error occurred when creating ONNXRuntime tensor with specific data
+    #[error("Failed to create tensor with data: {0}")]
+    CreateTensorWithData(OrtApiError),
+    /// Error occurred when filling a tensor with string data
+    #[error("Failed to fill string tensor: {0}")]
+    FillStringTensor(OrtApiError),
+    /// Error occurred when checking if ONNXRuntime tensor was properly initialized
+    #[error("Failed to check if value is a tensor: {0}")]
+    IsTensor(OrtApiError),
+    /// Error occurred when getting tensor type and shape
+    #[error("Failed to get tensor type and shape: {0}")]
+    GetTensorTypeAndShape(OrtApiError),
+    /// Error occurred when ONNXRuntime inference operation was called
+    #[error("Failed to run: {0}")]
+    Run(OrtApiError),
+    /// Error occurred when extracting data from an ONNXRuntime tensor into a C array to be used as an `ndarray::ArrayView`
+    #[error("Failed to get tensor data: {0}")]
+    GetTensorMutableData(OrtApiError),
+
+    /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
+    #[error("Failed to download ONNX model: {0}")]
+    DownloadError(#[from] OrtDownloadError),
+
+    /// Dimensions of input data and ONNX model loaded from file do not match
+    #[error("Dimensions do not match: {0:?}")]
+    NonMatchingDimensions(NonMatchingDimensionsError),
+    /// File does not exist
+    #[error("File {filename:?} does not exist")]
+    FileDoesNotExists {
+        /// Path which does not exist
+        filename: PathBuf,
+    },
+    /// Path contains invalid UTF-8
+    #[error("Path {path:?} cannot be converted to UTF-8")]
+    NonUtf8Path {
+        /// Path with invalid UTF-8
+        path: PathBuf,
+    },
+    /// Attempt to build a Rust `CString` from a null pointer
+    #[error("Failed to build CString when original contains null: {0}")]
+    CStringNulError(#[from] std::ffi::NulError),
+    #[error("{0} pointer should be null")]
+    /// Ort pointer should have been null
+    PointerShouldBeNull(String),
+    /// Ort pointer should not have been null
+    #[error("{0} pointer should not be null")]
+    PointerShouldNotBeNull(String),
+    /// ONNXRuntime Model has invalid dimensions
+    #[error("Invalid dimensions")]
+    InvalidDimensions,
+    /// The runtime type was undefined
+    #[error("Undefined Tensor Element Type")]
+    UndefinedTensorElementType,
+    /// Error occurred when checking if ONNXRuntime tensor was properly initialized
+    #[error("Failed to check if value is a tensor")]
+    IsTensorCheck,
+}
+
+/// Error used when the dimensions of the input (from the model and from the inference call)
+/// do not match (as they should).
+#[non_exhaustive]
+#[derive(Error, Debug)]
+pub enum NonMatchingDimensionsError {
+    /// Number of inputs from the model does not match the number of inputs from the inference call
+    #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
+    InputsCount {
+        /// Number of input dimensions used by inference call
+        inference_input_count: usize,
+        /// Number of input dimensions defined in model
+        model_input_count: usize,
+        /// Input dimensions used by inference call
+        inference_input: Vec<Vec<usize>>,
+        /// Input dimensions defined in model
+        model_input: Vec<Vec<Option<u32>>>,
+    },
+    /// Inputs length from the model does not match the expected input from the inference call
+    #[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
+    InputsLength {
+        /// Input dimensions used by inference call
+        inference_input: Vec<Vec<usize>>,
+        /// Input dimensions defined in model
+        model_input: Vec<Vec<Option<u32>>>,
+    },
+}
+
+/// Error details when an ONNXRuntime C API call fails
+#[non_exhaustive]
+#[derive(Error, Debug)]
+pub enum OrtApiError {
+    /// Details as reported by the ONNXRuntime C API in case of error
+    #[error("Error calling ONNX Runtime C function: {0}")]
+    Msg(String),
+    /// Details as reported by the ONNXRuntime C API in case of error cannot be converted to UTF-8
+    #[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
+    IntoStringError(std::ffi::IntoStringError),
+}
+
+/// Error from downloading a pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
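+///
+/// A sketch of how a caller might branch on these variants (hypothetical call
+/// site; `with_model_downloaded` requires the `model-fetching` feature):
+///
+/// ```ignore
+/// match session_builder.with_model_downloaded(model) {
+///     Ok(session) => { /* use the session */ }
+///     Err(OrtError::DownloadError(OrtDownloadError::ContentLengthError)) => {
+///         // The HTTP response carried no content-length to validate against.
+///     }
+///     Err(other) => return Err(other),
+/// }
+/// ```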
+#[non_exhaustive]
+#[derive(Error, Debug)]
+pub enum OrtDownloadError {
+    /// Generic input/output error
+    #[error("Error downloading data to file: {0}")]
+    IoError(#[from] io::Error),
+    #[cfg(feature = "model-fetching")]
+    /// Download error reported by ureq
+    #[error("Error downloading data to file: {0}")]
+    UreqError(#[from] Box<ureq::Error>),
+    /// Error getting content-length from an HTTP GET request
+    #[error("Error getting content-length")]
+    ContentLengthError,
+    /// Mismatch between the amount of downloaded and expected bytes
+    #[error("Error copying data to file: expected {expected} length, received {io}")]
+    CopyError {
+        /// Expected amount of bytes to download
+        expected: u64,
+        /// Number of bytes read from network and written to file
+        io: u64,
+    },
+}
+
+/// Wrapper type around an ONNXRuntime C API `OrtStatus` pointer
+///
+/// This wrapper exists to facilitate conversion from C raw pointers to Rust error types
+pub struct OrtStatusWrapper(*const sys::OrtStatus);
+
+impl From<*const sys::OrtStatus> for OrtStatusWrapper {
+    fn from(status: *const sys::OrtStatus) -> Self {
+        OrtStatusWrapper(status)
+    }
+}
+
+pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
+    ptr.is_null()
+        .then_some(())
+        .ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
+}
+
+pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
+    (!ptr.is_null())
+        .then_some(())
+        .ok_or_else(|| OrtError::PointerShouldNotBeNull(name.to_owned()))
+}
+
+impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
+    fn from(status: OrtStatusWrapper) -> Self {
+        if status.0.is_null() {
+            Ok(())
+        } else {
+            let raw: *const i8 = unsafe {
+                ENV.get()
+                    .unwrap()
+                    .lock()
+                    .unwrap()
+                    .api()
+                    .GetErrorMessage
+                    .unwrap()(status.0)
+            };
+            match char_p_to_string(raw) {
+                Ok(msg) => Err(OrtApiError::Msg(msg)),
+                Err(err) => match err {
+                    OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
+                        Err(OrtApiError::IntoStringError(e))
+                    }
+                    _ => unreachable!(),
+                },
+            }
+        }
+    }
+}
+
+pub(crate) fn status_to_result(
+    status: *const sys::OrtStatus,
+) -> std::result::Result<(), OrtApiError> {
+    let status_wrapper: OrtStatusWrapper = status.into();
+    status_wrapper.into()
+}
+
+/// A wrapper around a function on `OrtApi` that maps the status code into [`OrtApiError`]
+pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
+where
+    F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
+{
+    status_to_result(f(ENV.get().unwrap().lock().unwrap().api()))
+}
diff --git a/rust/onnxruntime/src/lib.rs b/rust/onnxruntime/src/lib.rs
new file mode 100644
index 0000000000000..ce4721ef4240f
--- /dev/null
+++ b/rust/onnxruntime/src/lib.rs
@@ -0,0 +1,560 @@
+#![warn(missing_docs)]
+
+//! ONNX Runtime
+//!
+//! This crate is a (safe) wrapper around Microsoft's [ONNX Runtime](https://github.com/microsoft/onnxruntime/)
+//! through its C API.
+//!
+//! From its [GitHub page](https://github.com/microsoft/onnxruntime/):
+//!
+//! > ONNX Runtime is a cross-platform, high performance ML inferencing and training accelerator.
+//!
+//! The (highly) unsafe [C API](https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_c_api.h)
+//! is wrapped using bindgen as [`onnxruntime-sys`](https://crates.io/crates/onnxruntime-sys).
+//!
+//! The unsafe bindings are wrapped in this crate to expose a safe API.
+//!
+//! For now, efforts are concentrated on the inference API. Training is _not_ supported.
+//!
+//! # Example
+//!
+//! 
The C++ example that uses the C API +//! ([`C_Api_Sample.cpp`](https://github.com/microsoft/onnxruntime/blob/v1.3.1/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp)) +//! was ported to +//! [`onnxruntime`](https://github.com/nbigaouette/onnxruntime-rs/blob/main/onnxruntime/examples/sample.rs). +//! +//! First, an environment must be created using and [`EnvBuilder`](environment/struct.EnvBuilder.html): +//! +//! ```no_run +//! # use std::error::Error; +//! # use std::env::var; +//! # use onnxruntime::{environment::Environment, LoggingLevel}; +//! # fn main() -> Result<(), Box> { +//! # let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok(); +//! +//! let builder = Environment::builder() +//! .with_name("test") +//! .with_log_level(LoggingLevel::Warning); +//! +//! let builder = if let Some(path) = path { +//! builder.with_library_path(path) +//! } else { +//! builder +//! }; +//! let environment = builder.build()?; +//! Ok(()) +//! } +//! ``` +//! +//! Then a [`Session`](session/struct.Session.html) is created from the environment, some options and an ONNX model file: +//! +//! ```no_run +//! # use std::error::Error; +//! # use std::env::var; +//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel}; +//! # fn main() -> Result<(), Box> { +//! # let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok(); +//! # +//! # let builder = Environment::builder() +//! # .with_name("test") +//! # .with_log_level(LoggingLevel::Warning); +//! # +//! # let builder = if let Some(path) = path { +//! # builder.with_library_path(path) +//! # } else { +//! # builder +//! # }; +//! # let environment = builder.build()?; +//! let mut session = environment +//! .new_session_builder()? +//! .with_graph_optimization_level(GraphOptimizationLevel::Basic)? +//! .with_intra_op_num_threads(1)? +//! .with_model_from_file("squeezenet.onnx")?; +//! # Ok(()) +//! # } +//! ``` +#![cfg_attr( + feature = "model-fetching", + doc = r##" +Instead of loading a model from file using [`with_model_from_file()`](session/struct.SessionBuilder.html#method.with_model_from_file), +a model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using +[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method +(requires the `model-fetching` feature). + +```no_run +# use std::error::Error; +# use std::env::var; +# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel}; +# fn main() -> Result<(), Box> { +# let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok(); +# +# let builder = Environment::builder() +# .with_name("test") +# .with_log_level(LoggingLevel::Warning); +# +# let builder = if let Some(path) = path { +# builder.with_library_path(path) +# } else { +# builder +# }; +# let environment = builder.build()?; + +let mut session = environment + .new_session_builder()? + .with_graph_optimization_level(GraphOptimizationLevel::Basic)? + .with_intra_op_num_threads(1)? + .with_model_downloaded(ImageClassification::SqueezeNet)?; +# Ok(()) +# } +``` + +See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available +to download. +"## +)] +//! +//! Inference will be run on data passed as an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html). +//! +//! ```no_run +//! # use std::error::Error; +//! # use std::env::var; +//! 
# use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::construct::ConstructTensor};
+//! # fn main() -> Result<(), Box<dyn Error>> {
+//! # let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
+//! #
+//! # let builder = Environment::builder()
+//! #     .with_name("test")
+//! #     .with_log_level(LoggingLevel::Warning);
+//! #
+//! # let builder = if let Some(path) = path {
+//! #     builder.with_library_path(path)
+//! # } else {
+//! #     builder
+//! # };
+//! # let environment = builder.build()?;
+//! # let mut session = environment
+//! #     .new_session_builder()?
+//! #     .with_graph_optimization_level(GraphOptimizationLevel::Basic)?
+//! #     .with_intra_op_num_threads(1)?
+//! #     .with_model_from_file("squeezenet.onnx")?;
+//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
+//! // Multiple inputs and outputs are possible
+//! let input_tensor = vec![array.into()];
+//! let outputs = session.run(input_tensor)?;
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! The outputs are of type [`OrtOwnedTensor`](tensor/ort_owned_tensor/struct.OrtOwnedTensor.html)s inside a vector,
+//! with the same length as the inputs.
+//!
+//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/main/onnxruntime/examples/sample.rs)
+//! example for more details.
+
+use onnxruntime_sys as sys;
+
+// Make functions `extern "stdcall"` for Windows 32bit.
+// This behaves like `extern "system"`.
+#[cfg(all(target_os = "windows", target_arch = "x86"))]
+macro_rules! extern_system_fn {
+    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
+    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
+    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
+    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
+}
+
+// Make functions `extern "C"` for normal targets.
+// This behaves like `extern "system"`.
+#[cfg(not(all(target_os = "windows", target_arch = "x86")))]
+macro_rules! extern_system_fn {
+    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
+    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
+    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
+    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
+}
+
+pub mod download;
+pub mod environment;
+pub mod error;
+mod memory;
+pub mod session;
+pub mod tensor;
+
+// Re-export
+pub use error::{OrtApiError, OrtError, Result};
+use sys::OnnxEnumInt;
+
+// Re-export ndarray as it's part of the public API anyway
+pub use ndarray;
+
+fn char_p_to_string(raw: *const i8) -> Result<String> {
+    let c_string = unsafe { std::ffi::CStr::from_ptr(raw).to_owned() };
+
+    match c_string.into_string() {
+        Ok(string) => Ok(string),
+        Err(e) => Err(OrtApiError::IntoStringError(e)),
+    }
+    .map_err(OrtError::StringConversion)
+}
+
+mod onnxruntime {
+    //! Module containing a custom logger, used to catch the runtime's own logging and send it
+    //! to Rust's tracing logging instead.
+
+    use std::ffi::CStr;
+    use tracing::{debug, error, info, span, trace, warn, Level};
+
+    use onnxruntime_sys as sys;
+
+    /// The runtime's logging reports the code location where a log entry was produced; it is parsed into this struct.
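+    ///
+    /// A sketch of the format this parser assumes (hypothetical input):
+    ///
+    /// ```ignore
+    /// let loc = CodeLocation::from("main.cc:42 SomeFunction");
+    /// assert_eq!(loc.file, "main.cc");
+    /// assert_eq!(loc.line_number, "42");
+    /// assert_eq!(loc.function, "SomeFunction");
+    /// ```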
+ #[derive(Debug)] + struct CodeLocation<'a> { + file: &'a str, + line_number: &'a str, + function: &'a str, + } + + impl<'a> From<&'a str> for CodeLocation<'a> { + fn from(code_location: &'a str) -> Self { + let mut splitter = code_location.split(' '); + let file_and_line_number = splitter.next().unwrap_or(""); + let function = splitter.next().unwrap_or(""); + let mut file_and_line_number_splitter = file_and_line_number.split(':'); + let file = file_and_line_number_splitter + .next() + .unwrap_or(""); + let line_number = file_and_line_number_splitter + .next() + .unwrap_or(""); + + CodeLocation { + file, + line_number, + function, + } + } + } + + extern_system_fn! { + /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate. + pub(crate) fn custom_logger( + _params: *mut std::ffi::c_void, + severity: sys::OrtLoggingLevel, + category: *const i8, + logid: *const i8, + code_location: *const i8, + message: *const i8, + ) { + let log_level = match severity { + sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE, + sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => Level::DEBUG, + sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => Level::INFO, + sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => Level::WARN, + sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR, + }; + + assert_ne!(category, std::ptr::null()); + let category = unsafe { CStr::from_ptr(category) }; + assert_ne!(code_location, std::ptr::null()); + let code_location = unsafe { CStr::from_ptr(code_location) } + .to_str() + .unwrap_or("unknown"); + assert_ne!(message, std::ptr::null()); + let message = unsafe { CStr::from_ptr(message) }; + + assert_ne!(logid, std::ptr::null()); + let logid = unsafe { CStr::from_ptr(logid) }; + + // Parse the code location + let code_location: CodeLocation = code_location.into(); + + let span = span!( + Level::TRACE, + "onnxruntime", + category = category.to_str().unwrap_or(""), + file = code_location.file, + line_number = code_location.line_number, + function = code_location.function, + logid = logid.to_str().unwrap_or(""), + ); + let _enter = span.enter(); + + match log_level { + Level::TRACE => trace!("{:?}", message), + Level::DEBUG => debug!("{:?}", message), + Level::INFO => info!("{:?}", message), + Level::WARN => warn!("{:?}", message), + Level::ERROR => error!("{:?}", message), + } + } + } +} + +/// Logging level of the ONNX Runtime C API +#[derive(Debug, Clone, Copy)] +#[cfg_attr(not(windows), repr(u32))] +#[cfg_attr(windows, repr(i32))] +pub enum LoggingLevel { + /// Verbose log level + Verbose = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt, + /// Info log level + Info = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO as OnnxEnumInt, + /// Warning log level + Warning = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt, + /// Error log level + Error = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt, + /// Fatal log level + Fatal = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt, +} + +impl From for sys::OrtLoggingLevel { + fn from(val: LoggingLevel) -> Self { + match val { + LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL, + } + } +} + +/// Optimization level performed 
by ONNX Runtime on the loaded graph
+///
+/// See the [official documentation](https://github.com/microsoft/onnxruntime/blob/main/docs/ONNX_Runtime_Graph_Optimizations.md)
+/// for more information on the different optimization levels.
+#[derive(Debug)]
+#[cfg_attr(not(windows), repr(u32))]
+#[cfg_attr(windows, repr(i32))]
+pub enum GraphOptimizationLevel {
+    /// Disable all optimizations
+    DisableAll = sys::GraphOptimizationLevel::ORT_DISABLE_ALL as OnnxEnumInt,
+    /// Basic optimizations
+    Basic = sys::GraphOptimizationLevel::ORT_ENABLE_BASIC as OnnxEnumInt,
+    /// Extended optimizations
+    Extended = sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED as OnnxEnumInt,
+    /// All optimizations enabled
+    All = sys::GraphOptimizationLevel::ORT_ENABLE_ALL as OnnxEnumInt,
+}
+
+impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
+    fn from(val: GraphOptimizationLevel) -> Self {
+        use GraphOptimizationLevel::{All, Basic, DisableAll, Extended};
+        match val {
+            DisableAll => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
+            Basic => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
+            Extended => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
+            All => sys::GraphOptimizationLevel::ORT_ENABLE_ALL,
+        }
+    }
+}
+
+// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
+// FIXME: Add tests to cover the commented out types
+/// Enum mapping ONNX Runtime's supported tensor types
+#[derive(Debug)]
+#[cfg_attr(not(windows), repr(u32))]
+#[cfg_attr(windows, repr(i32))]
+pub enum TensorElementDataType {
+    /// 32-bit floating point, equivalent to Rust's `f32`
+    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
+    /// Unsigned 8-bit int, equivalent to Rust's `u8`
+    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
+    /// Signed 8-bit int, equivalent to Rust's `i8`
+    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
+    /// Unsigned 16-bit int, equivalent to Rust's `u16`
+    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
+    /// Signed 16-bit int, equivalent to Rust's `i16`
+    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
+    /// Signed 32-bit int, equivalent to Rust's `i32`
+    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
+    /// Signed 64-bit int, equivalent to Rust's `i64`
+    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
+    /// String, equivalent to Rust's `String`
+    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
+    // /// Boolean, equivalent to Rust's `bool`
+    // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
+    // /// 16-bit floating point, equivalent to Rust's `f16`
+    // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
+    /// 64-bit floating point, equivalent to Rust's `f64`
+    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
+    /// Unsigned 32-bit int, equivalent to Rust's `u32`
+    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
+    /// Unsigned 64-bit int, equivalent to Rust's `u64`
+    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
+    // /// Complex 64-bit floating point, equivalent to Rust's `???`
+    // Complex64 = 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, + // /// Complex 128-bit floating point, equivalent to Rust's `???` + // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, + // /// Brain 16-bit floating point + // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, +} + +impl From for sys::ONNXTensorElementDataType { + fn from(val: TensorElementDataType) -> Self { + use TensorElementDataType::{ + Double, Float, Int16, Int32, Int64, Int8, String, Uint16, Uint32, Uint64, Uint8, + }; + match val { + Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + // Bool => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + // } + // Float16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + // } + Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + // Complex64 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + // } + // Complex128 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + // } + // Bfloat16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + // } + } + } +} + +/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) +pub trait TypeToTensorElementDataType { + /// Return the ONNX type for a Rust type + fn tensor_element_data_type() -> TensorElementDataType; + + /// If the type is `String`, returns `Some` with utf8 contents, else `None`. + fn try_utf8_bytes(&self) -> Option<&[u8]>; +} + +macro_rules! impl_type_trait { + ($type_:ty, $variant:ident) => { + impl TypeToTensorElementDataType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + // unsafe { std::mem::transmute(TensorElementDataType::$variant) } + TensorElementDataType::$variant + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + None + } + } + }; +} + +impl_type_trait!(f32, Float); +impl_type_trait!(u8, Uint8); +impl_type_trait!(i8, Int8); +impl_type_trait!(u16, Uint16); +impl_type_trait!(i16, Int16); +impl_type_trait!(i32, Int32); +impl_type_trait!(i64, Int64); +// impl_type_trait!(bool, Bool); +// impl_type_trait!(f16, Float16); +impl_type_trait!(f64, Double); +impl_type_trait!(u32, Uint32); +impl_type_trait!(u64, Uint64); +// impl_type_trait!(, Complex64); +// impl_type_trait!(, Complex128); +// impl_type_trait!(, Bfloat16); + +/// Adapter for common Rust string types to Onnx strings. 
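+///
+/// For example (a minimal sketch; both implementations simply expose the
+/// underlying bytes):
+///
+/// ```ignore
+/// assert_eq!("abc".utf8_bytes(), b"abc");
+/// assert_eq!(String::from("abc").utf8_bytes(), b"abc");
+/// ```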
+/// +/// It should be easy to use both `String` and `&str` as [`TensorElementDataType::String`] data, but +/// we can't define an automatic implementation for anything that implements `AsRef` as it +/// would conflict with the implementations of [`TypeToTensorElementDataType`] for primitive numeric +/// types (which might implement `AsRef` at some point in the future). +pub trait Utf8Data { + /// Returns the utf8 contents. + fn utf8_bytes(&self) -> &[u8]; +} + +impl Utf8Data for String { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a> Utf8Data for &'a str { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl TypeToTensorElementDataType for T { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + Some(self.utf8_bytes()) + } +} + +/// Allocator type +#[derive(Debug, Clone)] +#[repr(i32)] +pub enum AllocatorType { + // Invalid = sys::OrtAllocatorType::Invalid as i32, + /// Device allocator + Device = sys::OrtAllocatorType::OrtDeviceAllocator as i32, + /// Arena allocator + Arena = sys::OrtAllocatorType::OrtArenaAllocator as i32, +} + +impl From for sys::OrtAllocatorType { + fn from(val: AllocatorType) -> Self { + use AllocatorType::{Arena, Device}; + match val { + // Invalid => sys::OrtAllocatorType::Invalid, + Device => sys::OrtAllocatorType::OrtDeviceAllocator, + Arena => sys::OrtAllocatorType::OrtArenaAllocator, + } + } +} + +/// Memory type +/// +/// Only support ONNX's default type for now. +#[derive(Debug, Clone)] +#[repr(i32)] +pub enum MemType { + // FIXME: C API's `OrtMemType_OrtMemTypeCPU` defines it equal to `OrtMemType_OrtMemTypeCPUOutput`. How to handle this?? + // CPUInput = sys::OrtMemType::OrtMemTypeCPUInput as i32, + // CPUOutput = sys::OrtMemType::OrtMemTypeCPUOutput as i32, + // CPU = sys::OrtMemType::OrtMemTypeCPU as i32, + /// Default memory type + Default = sys::OrtMemType::OrtMemTypeDefault as i32, +} + +impl From for sys::OrtMemType { + fn from(val: MemType) -> Self { + use MemType::Default; + match val { + // CPUInput => sys::OrtMemType::OrtMemTypeCPUInput, + // CPUOutput => sys::OrtMemType::OrtMemTypeCPUOutput, + // CPU => sys::OrtMemType::OrtMemTypeCPU, + Default => sys::OrtMemType::OrtMemTypeDefault, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_char_p_to_string() { + let s = std::ffi::CString::new("foo").unwrap(); + let ptr = s.as_c_str().as_ptr(); + assert_eq!("foo", char_p_to_string(ptr).unwrap()); + } +} diff --git a/rust/onnxruntime/src/memory.rs b/rust/onnxruntime/src/memory.rs new file mode 100644 index 0000000000000..1688d433fe276 --- /dev/null +++ b/rust/onnxruntime/src/memory.rs @@ -0,0 +1,81 @@ +use tracing::debug; + +use onnxruntime_sys as sys; + +use crate::{ + environment::{Environment, _Environment}, + error::{assert_not_null_pointer, status_to_result, OrtError, Result}, + AllocatorType, MemType, +}; + +use tracing::error; + +#[derive(Debug)] +pub struct MemoryInfo { + pub ptr: *mut sys::OrtMemoryInfo, + env: _Environment, +} + +impl MemoryInfo { + #[tracing::instrument] + pub fn new(allocator: AllocatorType, memory_type: MemType, env: &Environment) -> Result { + debug!("Creating new memory info."); + let mut memory_info_ptr: *mut sys::OrtMemoryInfo = std::ptr::null_mut(); + let status = unsafe { + env.env().api().CreateCpuMemoryInfo.unwrap()( + allocator.into(), + memory_type.into(), + &mut memory_info_ptr, + ) + }; + status_to_result(status).map_err(OrtError::CreateCpuMemoryInfo)?; + 
assert_not_null_pointer(memory_info_ptr, "MemoryInfo")?; + + Ok(Self { + ptr: memory_info_ptr, + env: env.env.clone(), + }) + } +} + +impl Drop for MemoryInfo { + #[tracing::instrument] + fn drop(&mut self) { + if self.ptr.is_null() { + error!("MemoryInfo pointer is null, not dropping."); + } else { + debug!("Dropping the memory information."); + unsafe { self.env.env().api().ReleaseMemoryInfo.unwrap()(self.ptr) }; + } + + self.ptr = std::ptr::null_mut(); + } +} + +#[cfg(test)] +mod tests { + use std::env::var; + + use super::*; + use crate::{environment::tests::ONNX_RUNTIME_LIBRARY_PATH, LoggingLevel}; + use test_log::test; + + #[test] + fn memory_info_constructor_destructor() { + let path = var(ONNX_RUNTIME_LIBRARY_PATH).ok(); + + let builder = Environment::builder() + .with_name("test") + .with_log_level(LoggingLevel::Warning); + + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + let env = builder.build().unwrap(); + + let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, &env).unwrap(); + std::mem::drop(memory_info); + } +} diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs new file mode 100644 index 0000000000000..326426e35982c --- /dev/null +++ b/rust/onnxruntime/src/session.rs @@ -0,0 +1,806 @@ +//! Module containing session types + +use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path}; + +#[cfg(not(target_family = "windows"))] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; + +#[cfg(feature = "model-fetching")] +use std::env; + +use crate::{ + char_p_to_string, + environment::{Environment, _Environment}, + error::{ + assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError, + OrtApiError, OrtError, Result, + }, + memory::MemoryInfo, + tensor::{ + construct::ConstructTensor, + ort_output_tensor::{OrtOutput, OrtOwnedTensorExtractor}, + OrtOutputTensor, + }, + AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, +}; +use onnxruntime_sys as sys; + +use tracing::{debug, error}; + +#[cfg(feature = "model-fetching")] +use crate::{download::AvailableOnnxModel, error::OrtDownloadError}; + +/// Type used to create a session using the _builder pattern_ +/// +/// A `SessionBuilder` is created by calling the +/// [`Environment::new_session_builder()`](../env/struct.Environment.html#method.new_session_builder) +/// method on the environment. +/// +/// Once created, use the different methods to configure the session. +/// +/// Once configured, use the [`SessionBuilder::with_model_from_file()`](../session/struct.SessionBuilder.html#method.with_model_from_file) +/// method to "commit" the builder configuration into a [`Session`](../session/struct.Session.html). +/// +/// # Example +/// +/// ```no_run +/// # use std::error::Error; +/// # use std::env::var; +/// # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel}; +/// # fn main() -> Result<(), Box> { +/// # let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok(); +/// +/// let builder = Environment::builder() +/// .with_name("test") +/// .with_log_level(LoggingLevel::Warning); +/// +/// let builder = if let Some(path) = path { +/// builder.with_library_path(path) +/// } else { +/// builder +/// }; +/// let environment = builder.build()?; +/// +/// let mut session = environment +/// .new_session_builder()? +/// .with_graph_optimization_level(GraphOptimizationLevel::Basic)? 
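+///     // hypothetical values: tune the thread count and optimization level for your host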
+/// .with_intra_op_num_threads(1)? +/// .with_model_from_file("squeezenet.onnx")?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct SessionBuilder<'a> { + env: &'a Environment, + session_options_ptr: *mut sys::OrtSessionOptions, + + allocator: AllocatorType, + memory_type: MemType, +} + +impl<'a> Drop for SessionBuilder<'a> { + #[tracing::instrument] + fn drop(&mut self) { + if self.session_options_ptr.is_null() { + error!("Session options pointer is null, not dropping"); + } else { + debug!("Dropping the session options."); + unsafe { + self.env.env().api().ReleaseSessionOptions.unwrap()(self.session_options_ptr) + }; + } + } +} + +impl<'a> SessionBuilder<'a> { + pub(crate) fn new(env: &'a Environment) -> Result> { + let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut(); + let status = + unsafe { env.env().api().CreateSessionOptions.unwrap()(&mut session_options_ptr) }; + + status_to_result(status).map_err(OrtError::SessionOptions)?; + assert_null_pointer(status, "SessionStatus")?; + assert_not_null_pointer(session_options_ptr, "SessionOptions")?; + + Ok(SessionBuilder { + env, + session_options_ptr, + allocator: AllocatorType::Arena, + memory_type: MemType::Default, + }) + } + + /// Configure the session to use a number of threads + pub fn with_intra_op_num_threads(self, num_threads: i16) -> Result> { + // FIXME: Pre-built binaries use OpenMP, set env variable instead + + // We use a u16 in the builder to cover the 16-bits positive values of a i32. + let num_threads = i32::from(num_threads); + let status = unsafe { + self.env.env().api().SetIntraOpNumThreads.unwrap()( + self.session_options_ptr, + num_threads, + ) + }; + status_to_result(status).map_err(OrtError::SessionOptions)?; + assert_null_pointer(status, "SessionStatus")?; + Ok(self) + } + + /// Set the session's optimization level + pub fn with_graph_optimization_level( + self, + opt_level: GraphOptimizationLevel, + ) -> Result> { + // Sets graph optimization level + unsafe { + self.env + .env() + .api() + .SetSessionGraphOptimizationLevel + .unwrap()(self.session_options_ptr, opt_level.into()) + }; + Ok(self) + } + + /// Set the session's allocator + /// + /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena) + pub fn with_allocator(mut self, allocator: AllocatorType) -> Result> { + self.allocator = allocator; + Ok(self) + } + + /// Set the session's memory type + /// + /// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default) + pub fn with_memory_type(mut self, memory_type: MemType) -> Result> { + self.memory_type = memory_type; + Ok(self) + } + + /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session + #[cfg(feature = "model-fetching")] + pub fn with_model_downloaded(self, model: M) -> Result + where + M: Into, + { + self.with_model_downloaded_monomorphized(model.into()) + } + + #[cfg(feature = "model-fetching")] + fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result { + let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?; + let downloaded_path = model.download_to(download_dir)?; + self.with_model_from_file(downloaded_path) + } + + // TODO: Add all functions changing the options. + // See all OrtApi methods taking a `options: *mut OrtSessionOptions`. + + /// Load an ONNX graph from a file and commit the session + pub fn with_model_from_file

(self, model_filepath_ref: P) -> Result + where + P: AsRef + 'a, + { + let model_filepath = model_filepath_ref.as_ref(); + let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut(); + + if !model_filepath.exists() { + return Err(OrtError::FileDoesNotExists { + filename: model_filepath.to_path_buf(), + }); + } + + // Build an OsString than a vector of bytes to pass to C + let model_path = std::ffi::OsString::from(model_filepath); + #[cfg(target_family = "windows")] + let model_path: Vec = model_path + .encode_wide() + .chain(std::iter::once(0)) // Make sure we have a null terminated string + .collect(); + #[cfg(not(target_family = "windows"))] + let model_path: Vec = model_path + .as_bytes() + .iter() + .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string + .map(|b| *b as std::os::raw::c_char) + .collect(); + + unsafe { + let api = self.env.env().api(); + + let status = api.CreateSession.unwrap()( + self.env.env().env_ptr, + model_path.as_ptr(), + self.session_options_ptr, + &mut session_ptr, + ); + + status_to_result(status).map_err(OrtError::Session)?; + assert_null_pointer(status, "SessionStatus")?; + assert_not_null_pointer(session_ptr, "Session")?; + }; + let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut(); + let status = unsafe { + self.env.env().api().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) + }; + status_to_result(status).map_err(OrtError::Allocator)?; + assert_null_pointer(status, "SessionStatus")?; + assert_not_null_pointer(allocator_ptr, "Allocator")?; + + let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, &self.env)?; + unsafe { + // Extract input and output properties + let num_input_nodes = + dangerous::extract_inputs_count(session_ptr, self.env.env.clone())?; + let num_output_nodes = + dangerous::extract_outputs_count(session_ptr, self.env.env.clone())?; + let inputs = (0..num_input_nodes) + .map(|i| { + dangerous::extract_input(session_ptr, allocator_ptr, i, self.env.env.clone()) + }) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| { + dangerous::extract_output(session_ptr, allocator_ptr, i, self.env.env.clone()) + }) + .collect::>>()?; + + Ok(Session { + env: self.env.env.clone(), + session_ptr, + allocator_ptr, + memory_info, + inputs, + outputs, + }) + } + } + + /// Load an ONNX graph from memory and commit the session + pub fn with_model_from_memory(self, model_bytes: B) -> Result + where + B: AsRef<[u8]>, + { + self.with_model_from_memory_monomorphized(model_bytes.as_ref()) + } + + fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result { + let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut(); + unsafe { + let api = self.env.env().api(); + + let model_data = model_bytes.as_ptr().cast::(); + let model_data_length = model_bytes.len(); + let status = api.CreateSessionFromArray.unwrap()( + self.env.env().env_ptr, + model_data, + model_data_length, + self.session_options_ptr, + &mut session_ptr, + ); + + status_to_result(status).map_err(OrtError::Session)?; + assert_null_pointer(status, "SessionStatus")?; + assert_not_null_pointer(session_ptr, "Session")?; + }; + let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut(); + let status = unsafe { + self.env.env().api().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) + }; + status_to_result(status).map_err(OrtError::Allocator)?; + assert_null_pointer(status, "SessionStatus")?; + assert_not_null_pointer(allocator_ptr, "Allocator")?; + + let memory_info = 
MemoryInfo::new(AllocatorType::Arena, MemType::Default, &self.env)?; + unsafe { + // Extract input and output properties + let num_input_nodes = + dangerous::extract_inputs_count(session_ptr, self.env.env.clone())?; + let num_output_nodes = + dangerous::extract_outputs_count(session_ptr, self.env.env.clone())?; + let inputs = (0..num_input_nodes) + .map(|i| { + dangerous::extract_input(session_ptr, allocator_ptr, i, self.env.env.clone()) + }) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| { + dangerous::extract_output(session_ptr, allocator_ptr, i, self.env.env.clone()) + }) + .collect::>>()?; + + Ok(Session { + env: self.env.env.clone(), + session_ptr, + allocator_ptr, + memory_info, + inputs, + outputs, + }) + } + } +} + +/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html) +#[derive(Debug)] +pub struct Session { + env: _Environment, + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + memory_info: MemoryInfo, + /// Information about the ONNX's inputs as stored in loaded file + pub inputs: Vec, + /// Information about the ONNX's outputs as stored in loaded file + pub outputs: Vec, +} + +/// Information about an ONNX's input as stored in loaded file +#[derive(Debug)] +pub struct Input { + /// Name of the input layer + pub name: String, + /// Type of the input layer's elements + pub input_type: TensorElementDataType, + /// Shape of the input layer + /// + /// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values. + pub dimensions: Vec>, +} + +/// Information about an ONNX's output as stored in loaded file +#[derive(Debug)] +pub struct Output { + /// Name of the output layer + pub name: String, + /// Type of the output layer's elements + pub output_type: TensorElementDataType, + /// Shape of the output layer + /// + /// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values. + pub dimensions: Vec>, +} + +impl Input { + /// Return an iterator over the shape elements of the input layer + /// + /// Note: The member [`Input::dimensions`](struct.Input.html#structfield.dimensions) + /// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the + /// iterator converts to `usize`. + pub fn dimensions(&self) -> impl Iterator> + '_ { + self.dimensions.iter().map(|d| d.map(|d2| d2 as usize)) + } +} + +impl Output { + /// Return an iterator over the shape elements of the output layer + /// + /// Note: The member [`Output::dimensions`](struct.Output.html#structfield.dimensions) + /// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the + /// iterator converts to `usize`. + pub fn dimensions(&self) -> impl Iterator> + '_ { + self.dimensions.iter().map(|d| d.map(|d2| d2 as usize)) + } +} + +impl Drop for Session { + #[tracing::instrument] + fn drop(&mut self) { + debug!("Dropping the session."); + if self.session_ptr.is_null() { + error!("Session pointer is null, not dropping."); + } else { + unsafe { self.env.env().api().ReleaseSession.unwrap()(self.session_ptr) }; + } + + self.session_ptr = std::ptr::null_mut(); + self.allocator_ptr = std::ptr::null_mut(); + } +} + +unsafe impl Send for Session {} + +unsafe impl Sync for Session {} + +impl Session { + /// Run the input data through the ONNX graph, performing inference. + /// + /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus + /// used for the input data here. 
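+    ///
+    /// A minimal sketch (assuming `session` was built as in the crate-level docs):
+    ///
+    /// ```ignore
+    /// let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
+    /// let outputs = session.run(vec![array.into()])?;
+    /// assert_eq!(outputs.len(), session.outputs.len()); // one tensor per model output
+    /// ```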
+ pub fn run<'input, 'output>( + &'output self, + mut input_arrays: impl AsMut<[Box]> + 'input, + ) -> Result>> { + let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = + vec![std::ptr::null_mut(); self.outputs.len()]; + + let output_names_cstring: Vec = self + .outputs + .iter() + .map(|output| output.name.clone()) + .map(|n| CString::new(n).unwrap()) + .collect(); + let output_names_ptr: Vec<*const i8> = output_names_cstring + .iter() + .map(|n| n.as_ptr().cast::()) + .collect(); + + let input_names_ptr: Vec<*const i8> = self + .inputs + .iter() + .map(|input| input.name.clone()) + .map(|n| CString::new(n).unwrap()) + .map(|n| n.into_raw() as *const i8) + .collect(); + + { + let memory_info = &self.memory_info; + + let allocator = self.allocator_ptr; + + let arr = input_arrays.as_mut(); + + let input_tensors = arr + .into_iter() + .map(|v| v.construct(memory_info, allocator)) + .collect::>>()?; + + let input_arrays_shapes: Vec> = + input_tensors.iter().map(|v| v.shape().to_vec()).collect(); + + self.validate_input_shapes(&input_arrays_shapes)?; + + // Build arguments to Run() + + let input_ort_values: Vec<*const sys::OrtValue> = input_tensors + .iter() + .map(|input_array_ort| input_array_ort.ptr() as *const sys::OrtValue) + .collect(); + + let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); + + let status = unsafe { + self.env.env().api().Run.unwrap()( + self.session_ptr, + run_options_ptr, + input_names_ptr.as_ptr(), + input_ort_values.as_ptr(), + input_ort_values.len(), + output_names_ptr.as_ptr(), + output_names_ptr.len(), + output_tensor_extractors_ptrs.as_mut_ptr(), + ) + }; + status_to_result(status).map_err(OrtError::Run)?; + } + + let outputs: Result> = output_tensor_extractors_ptrs + .into_iter() + .map(|ptr| { + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + let status = unsafe { + self.env.env().api().GetTensorTypeAndShape.unwrap()( + ptr, + &mut tensor_info_ptr as _, + ) + }; + status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; + let dims = unsafe { get_tensor_dimensions(tensor_info_ptr, self.env.clone()) }; + + unsafe { + self.env.env().api().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) + }; + let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); + + let mut output_tensor_extractor = + OrtOwnedTensorExtractor::new(dims, self.env.clone()); + output_tensor_extractor.tensor_ptr = ptr; + + output_tensor_extractor.extract() + }) + .collect(); + + // Reconvert to CString so drop impl is called and memory is freed + let cstrings: Result> = input_names_ptr + .into_iter() + .map(|p| { + assert_not_null_pointer(p, "i8 for CString")?; + unsafe { Ok(CString::from_raw(p as *mut i8)) } + }) + .collect(); + cstrings?; + + outputs? 
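+            // Convert each successfully extracted tensor into the public
+            // `OrtOutput` type; the conversion itself is fallible, so collect
+            // into a `Result`.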
+ .into_iter() + .map(|v| OrtOutput::try_from(v)) + .collect() + } + + fn validate_input_shapes(&self, input_array_shapes: &[Vec]) -> Result<()> { + // ****************************************************************** + // FIXME: Properly handle errors here + // Make sure all dimensions match (except dynamic ones) + + // Verify length of inputs + if input_array_shapes.len() != self.inputs.len() { + error!( + "Non-matching number of inputs: {} (inference) vs {} (model)", + input_array_shapes.len(), + self.inputs.len() + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsCount { + inference_input_count: 0, + model_input_count: 0, + inference_input: input_array_shapes.to_vec(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + // Verify length of each individual inputs + let inputs_different_length = input_array_shapes + .iter() + .zip(self.inputs.iter()) + .any(|(l, r)| l.len() != r.dimensions.len()); + if inputs_different_length { + error!( + "Different input lengths: {:?} vs {:?}", + self.inputs, input_array_shapes + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsLength { + inference_input: input_array_shapes + .iter() + .map(|input_array| input_array.to_vec()) + .collect(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + // Verify shape of each individual inputs + let inputs_different_shape = + input_array_shapes + .iter() + .zip(self.inputs.iter()) + .any(|(l, r)| { + let l_shape = l; + let r_shape = r.dimensions.as_slice(); + l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { + Some(r3) => *r3 as usize != *l2, + None => false, // None means dynamic size; in that case shape always match + }) + }); + if inputs_different_shape { + error!( + "Different input lengths: {:?} vs {:?}", + self.inputs, input_array_shapes + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsLength { + inference_input: input_array_shapes + .iter() + .map(|input_array| input_array.to_vec()) + .collect(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + Ok(()) + } +} + +unsafe fn get_tensor_dimensions( + tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, + env: _Environment, +) -> Result> { + let mut num_dims = 0; + let status = env.env().api().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); + status_to_result(status).map_err(OrtError::GetDimensionsCount)?; + (num_dims != 0) + .then_some(()) + .ok_or(OrtError::InvalidDimensions)?; + + let mut node_dims: Vec = vec![0; num_dims as usize]; + let status = env.env().api().GetDimensions.unwrap()( + tensor_info_ptr, + node_dims.as_mut_ptr(), // FIXME: UB? + num_dims, + ); + status_to_result(status).map_err(OrtError::GetDimensions)?; + Ok(node_dims) +} + +/// This module contains dangerous functions working on raw pointers. +/// Those functions are only to be used from inside the +/// `SessionBuilder::with_model_from_file()` method. 
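+/// They are also called from `SessionBuilder::with_model_from_memory()`, which
+/// upholds the same invariant: every raw pointer handed to them must still be live.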
+mod dangerous { + use super::{ + assert_not_null_pointer, assert_null_pointer, char_p_to_string, get_tensor_dimensions, + status_to_result, sys, Input, OrtApiError, OrtError, Output, Result, TensorElementDataType, + }; + + use crate::environment::_Environment; + + pub(super) unsafe fn extract_inputs_count( + session_ptr: *mut sys::OrtSession, + env: _Environment, + ) -> Result { + let f = env.env().api().SessionGetInputCount.unwrap(); + extract_io_count(f, session_ptr) + } + + pub(super) unsafe fn extract_outputs_count( + session_ptr: *mut sys::OrtSession, + env: _Environment, + ) -> Result { + let f = env.env().api().SessionGetOutputCount.unwrap(); + extract_io_count(f, session_ptr) + } + + fn extract_io_count( + f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus }, + session_ptr: *mut sys::OrtSession, + ) -> Result { + let mut num_nodes: usize = 0; + let status = unsafe { f(session_ptr, &mut num_nodes) }; + status_to_result(status).map_err(OrtError::InOutCount)?; + assert_null_pointer(status, "SessionStatus")?; + (num_nodes != 0).then_some(()).ok_or_else(|| { + OrtError::InOutCount(OrtApiError::Msg("No nodes in model".to_owned())) + })?; + Ok(num_nodes) + } + + unsafe fn extract_input_name( + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + i: usize, + env: _Environment, + ) -> Result { + let f = env.env().api().SessionGetInputName.unwrap(); + extract_io_name(f, session_ptr, allocator_ptr, i, env) + } + + unsafe fn extract_output_name( + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + i: usize, + env: _Environment, + ) -> Result { + let f = env.env().api().SessionGetOutputName.unwrap(); + extract_io_name(f, session_ptr, allocator_ptr, i, env) + } + + fn extract_io_name( + f: extern_system_fn! { unsafe fn( + *const sys::OrtSession, + usize, + *mut sys::OrtAllocator, + *mut *mut i8, + ) -> *mut sys::OrtStatus }, + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + i: usize, + env: _Environment, + ) -> Result { + let mut name_bytes: *mut i8 = std::ptr::null_mut(); + + let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) }; + status_to_result(status).map_err(OrtError::InputName)?; + assert_not_null_pointer(name_bytes, "InputName")?; + + let name = char_p_to_string(name_bytes)?; + + unsafe { + env.env().api().AllocatorFree.unwrap()( + allocator_ptr, + name_bytes as *mut std::ffi::c_void, + ) + }; + + Ok(name) + } + + pub(super) unsafe fn extract_input( + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + i: usize, + env: _Environment, + ) -> Result { + let input_name = extract_input_name(session_ptr, allocator_ptr, i, env.clone())?; + let f = env.env().api().SessionGetInputTypeInfo.unwrap(); + let (input_type, dimensions) = extract_io(f, session_ptr, i, env)?; + Ok(Input { + name: input_name, + input_type, + dimensions, + }) + } + + pub(super) unsafe fn extract_output( + session_ptr: *mut sys::OrtSession, + allocator_ptr: *mut sys::OrtAllocator, + i: usize, + env: _Environment, + ) -> Result { + let output_name = extract_output_name(session_ptr, allocator_ptr, i, env.clone())?; + let f = env.env().api().SessionGetOutputTypeInfo.unwrap(); + let (output_type, dimensions) = extract_io(f, session_ptr, i, env)?; + Ok(Output { + name: output_name, + output_type, + dimensions, + }) + } + + fn extract_io( + f: extern_system_fn! 
{ unsafe fn( + *const sys::OrtSession, + usize, + *mut *mut sys::OrtTypeInfo, + ) -> *mut sys::OrtStatus }, + session_ptr: *mut sys::OrtSession, + i: usize, + env: _Environment, + ) -> Result<(TensorElementDataType, Vec>)> { + let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut(); + + let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) }; + status_to_result(status).map_err(OrtError::GetTypeInfo)?; + assert_not_null_pointer(typeinfo_ptr, "TypeInfo")?; + + let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + let status = unsafe { + env.env().api().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr) + }; + status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; + assert_not_null_pointer(tensor_info_ptr, "TensorInfo")?; + + let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + let status = unsafe { + env.env().api().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) + }; + status_to_result(status).map_err(OrtError::TensorElementType)?; + (type_sys != sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) + .then_some(()) + .ok_or(OrtError::UndefinedTensorElementType)?; + // This transmute should be safe since its value is read from GetTensorElementType which we must trust. + let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; + + // info!("{} : type={}", i, type_); + + let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr, env.clone())? }; + + // for j in 0..num_dims { + // info!("{} : dim {}={}", i, j, node_dims[j as usize]); + // } + + unsafe { env.env().api().ReleaseTypeInfo.unwrap()(typeinfo_ptr) }; + + Ok(( + io_type, + node_dims + .into_iter() + .map(|d| if d == -1 { None } else { Some(d as u32) }) + .collect(), + )) + } +} diff --git a/rust/onnxruntime/src/tensor.rs b/rust/onnxruntime/src/tensor.rs new file mode 100644 index 0000000000000..0f383f3ad59b6 --- /dev/null +++ b/rust/onnxruntime/src/tensor.rs @@ -0,0 +1,31 @@ +//! Module containing tensor types. +//! +//! Two main types of tensors are available. +//! +//! The first one, [`Tensor`](struct.Tensor.html), +//! is an _owned_ tensor that is backed by [`ndarray`](https://crates.io/crates/ndarray). +//! This kind of tensor is used to pass input data for the inference. +//! +//! The second one, [`OrtOwnedTensor`](struct.OrtOwnedTensor.html), is used +//! internally to pass to the ONNX Runtime inference execution to place +//! its output values. It is built using a [`OrtOwnedTensorExtractor`](struct.OrtOwnedTensorExtractor.html) +//! following the builder pattern. +//! +//! Once "extracted" from the runtime environment, this tensor will contain an +//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) +//! containing _a view_ of the data. When going out of scope, this tensor will free the required +//! memory on the C side. +//! +//! **NOTE**: Tensors are not meant to be built directly. When performing inference, +//! the [`Session::run()`](../session/struct.Session.html#method.run) method takes +//! an `ndarray::Array` as input (taking ownership of it) and will convert it internally +//! to a [`Tensor`](struct.Tensor.html). After inference, a [`OrtOwnedTensor`](struct.OrtOwnedTensor.html) +//! will be returned by the method which can be derefed into its internal +//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). 
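+//!
+//! A minimal sketch of the flow described above (assuming `session` was built
+//! as in the crate-level docs):
+//!
+//! ```ignore
+//! let input = ndarray::Array::linspace(0.0_f32, 1.0, 100);
+//! let outputs = session.run(vec![input.into()])?;
+//! // Each element of `outputs` wraps a tensor that derefs into an ndarray view.
+//! println!("{} output tensor(s)", outputs.len());
+//! ```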
diff --git a/rust/onnxruntime/src/tensor.rs b/rust/onnxruntime/src/tensor.rs
new file mode 100644
index 0000000000000..0f383f3ad59b6
--- /dev/null
+++ b/rust/onnxruntime/src/tensor.rs
@@ -0,0 +1,31 @@
+//! Module containing tensor types.
+//!
+//! Two main types of tensors are available.
+//!
+//! The first one, [`Tensor`](struct.Tensor.html),
+//! is an _owned_ tensor that is backed by [`ndarray`](https://crates.io/crates/ndarray).
+//! This kind of tensor is used to pass input data for the inference.
+//!
+//! The second one, [`OrtOwnedTensor`](struct.OrtOwnedTensor.html), is used
+//! internally to receive the inference's output values from the ONNX Runtime execution.
+//! It is built using an [`OrtOwnedTensorExtractor`](struct.OrtOwnedTensorExtractor.html)
+//! following the builder pattern.
+//!
+//! Once "extracted" from the runtime environment, this tensor will contain an
+//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)
+//! containing _a view_ of the data. When going out of scope, this tensor will free the
+//! associated memory on the C side.
+//!
+//! **NOTE**: Tensors are not meant to be built directly. When performing inference,
+//! the [`Session::run()`](../session/struct.Session.html#method.run) method takes
+//! an `ndarray::Array` as input (taking ownership of it) and will convert it internally
+//! to a [`Tensor`](struct.Tensor.html). After inference, the method returns an
+//! [`OrtOwnedTensor`](struct.OrtOwnedTensor.html) which can be dereferenced into its internal
+//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
+
+pub mod construct;
+pub mod ndarray_tensor;
+pub mod ort_input_tensor;
+pub mod ort_output_tensor;
+
+pub use ort_output_tensor::{OrtOutputTensor, WithOutputTensor};
diff --git a/rust/onnxruntime/src/tensor/construct.rs b/rust/onnxruntime/src/tensor/construct.rs
new file mode 100644
index 0000000000000..97f70b131ea0a
--- /dev/null
+++ b/rust/onnxruntime/src/tensor/construct.rs
@@ -0,0 +1,34 @@
+//! Module containing the `ConstructTensor` trait, used to convert inputs into tensors.
+
+use crate::{memory::MemoryInfo, OrtError};
+use onnxruntime_sys::{OrtAllocator, OrtValue};
+use std::fmt::Debug;
+
+/// The input type for the Rust onnxruntime `Session::run`.
+pub trait ConstructTensor: Debug {
+    /// Construct an `OrtTensor` input using the `MemoryInfo` and a raw pointer to the `OrtAllocator`.
+    fn construct<'a>(
+        &'a mut self,
+        memory_info: &MemoryInfo,
+        allocator: *mut OrtAllocator,
+    ) -> Result<Box<dyn InputTensor + 'a>, OrtError>;
+}
+
+/// Allows the return value of `ConstructTensor::construct`
+/// to be generic.
+pub trait InputTensor {
+    /// The input tensor's shape
+    fn shape(&self) -> &[usize];
+
+    /// The input tensor's ptr
+    fn ptr(&self) -> *mut OrtValue;
+}
+
+impl<'a, T> From<T> for Box<dyn ConstructTensor + 'a>
+where
+    T: ConstructTensor + 'a,
+{
+    fn from(other: T) -> Self {
+        Box::new(other)
+    }
+}
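With `ConstructTensor` in place, the end-to-end flow described by the `tensor` module docs looks like this in practice. A hedged sketch mirroring the integration tests; `environment` and the model path are assumed, and the input shape is illustrative:

```rust
// Sketch: `environment` as in the integration tests; "model.onnx" is a placeholder.
let session = environment
    .new_session_builder()
    .unwrap()
    .with_model_from_file("model.onnx")
    .unwrap();

let array = ndarray::Array::from_shape_vec((1, 3), vec![1.0_f32, 2.0, 3.0]).unwrap();

// `into()` goes through the From<T> impl above, boxing the array as a ConstructTensor.
let outputs = session.run(vec![array.into()]).unwrap();

// Each output derefs into an ndarray::ArrayView over runtime-owned memory.
let view = outputs[0].float_array().unwrap();
println!("output shape: {:?}", view.shape());
```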
diff --git a/rust/onnxruntime/src/tensor/ndarray_tensor.rs b/rust/onnxruntime/src/tensor/ndarray_tensor.rs
new file mode 100644
index 0000000000000..dea8d161b243b
--- /dev/null
+++ b/rust/onnxruntime/src/tensor/ndarray_tensor.rs
@@ -0,0 +1,210 @@
+//! Module containing a tensor trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
+
+use ndarray::{Array, ArrayBase};
+
+/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
+/// with useful tensor operations.
+///
+/// # Generic
+///
+/// The trait is generic over:
+/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
+/// * `T`: Type contained inside the tensor (for example `f32`)
+/// * `D`: Tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
+pub trait NdArrayTensor<S, T, D> {
+    /// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis
+    ///
+    /// # Trait Bounds
+    ///
+    /// The function is generic and thus has some trait bounds:
+    /// * `D: ndarray::RemoveAxis`: The summation over an axis reduces the dimension of the tensor. A 0-D tensor thus
+    ///   cannot have a softmax calculated.
+    /// * `S: ndarray::Data + ndarray::RawData<Elem = T>`: The storage of the tensor can be an owned
+    ///   array ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an array view
+    ///   ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)).
+    /// * `<S as ndarray::RawData>::Elem: std::clone::Clone`: The elements of the tensor must be `Clone`.
+    /// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must be workable
+    ///   as floats and must support `-=` and `/=` operations.
+    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
+    where
+        D: ndarray::RemoveAxis,
+        S: ndarray::Data + ndarray::RawData<Elem = T>,
+        <S as ndarray::RawData>::Elem: std::clone::Clone,
+        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
+}
+
+impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
+where
+    D: ndarray::RemoveAxis,
+    S: ndarray::Data + ndarray::RawData<Elem = T>,
+    <S as ndarray::RawData>::Elem: std::clone::Clone,
+    T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
+{
+    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
+        let mut new_array: Array<T, D> = self.to_owned();
+        // FIXME: Change to non-overflowing formula
+        // e = np.exp(A - np.sum(A, axis=1, keepdims=True))
+        // np.exp(a) / np.sum(np.exp(a))
+        new_array.map_inplace(|v| *v = v.exp());
+        let sum = new_array.sum_axis(axis).insert_axis(axis);
+        new_array /= &sum;
+
+        new_array
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ndarray::{arr1, arr2, arr3};
+    use test_log::test;
+
+    // Every test row uses the same logits, and thus the same expected softmax.
+    const INPUT_ROW: [f32; 7] = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0];
+    const EXPECTED_ROW: [f32; 7] = [
+        0.023_640_54, 0.064_261_66, 0.174_681_3, 0.474_833,
+        0.023_640_54, 0.064_261_66, 0.174_681_3,
+    ];
+
+    #[test]
+    fn softmax_1d() {
+        let array = arr1(&INPUT_ROW);
+        let expected_softmax = arr1(&EXPECTED_ROW);
+
+        let softmax = array.softmax(ndarray::Axis(0));
+
+        assert_eq!(softmax.shape(), expected_softmax.shape());
+
+        let diff = softmax - expected_softmax;
+
+        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
+    }
+
+    #[test]
+    fn softmax_2d() {
+        let array = arr2(&[INPUT_ROW; 2]);
+        let expected_softmax = arr2(&[EXPECTED_ROW; 2]);
+
+        let softmax = array.softmax(ndarray::Axis(1));
+
+        assert_eq!(softmax.shape(), expected_softmax.shape());
+
+        let diff = softmax - expected_softmax;
+
+        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
+    }
+
+    #[test]
+    fn softmax_3d() {
+        let array = arr3(&[[INPUT_ROW; 2]; 3]);
+        let expected_softmax = arr3(&[[EXPECTED_ROW; 2]; 3]);
+
+        let softmax = array.softmax(ndarray::Axis(2));
+
+        assert_eq!(softmax.shape(), expected_softmax.shape());
+
+        let diff = softmax - expected_softmax;
+
+        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
+    }
+}
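The `FIXME` inside `softmax` above flags that exponentiating raw logits can overflow. The standard remedy is to subtract the per-axis maximum before calling `exp`, which leaves the result mathematically unchanged since softmax is invariant to a constant shift. A sketch of that variant, not part of this diff and specialized to `f32` for brevity:

```rust
use ndarray::{Array, Axis, RemoveAxis};

// Numerically stable softmax sketch: exp(x - max) cannot overflow for large
// logits, and normalizing afterwards yields the same values as exp(x) / sum(exp(x)).
fn stable_softmax<D: RemoveAxis>(a: &Array<f32, D>, axis: Axis) -> Array<f32, D> {
    let mut out = a.to_owned();
    let max = out
        .fold_axis(axis, f32::NEG_INFINITY, |m, &v| m.max(v))
        .insert_axis(axis);
    out -= &max; // broadcast subtraction of the per-axis maximum
    out.map_inplace(|v| *v = v.exp());
    let sum = out.sum_axis(axis).insert_axis(axis);
    out /= &sum;
    out
}
```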
diff --git a/rust/onnxruntime/src/tensor/ort_input_tensor.rs b/rust/onnxruntime/src/tensor/ort_input_tensor.rs
new file mode 100644
index 0000000000000..f2cf0ee8a1d4a
--- /dev/null
+++ b/rust/onnxruntime/src/tensor/ort_input_tensor.rs
@@ -0,0 +1,325 @@
+//! Module containing a tensor with memory owned by Rust
+
+use super::construct::{ConstructTensor, InputTensor};
+use crate::{
+    environment::ENV,
+    error::{assert_not_null_pointer, call_ort, status_to_result},
+    memory::MemoryInfo,
+    OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
+};
+use ndarray::{Array, Dimension};
+use onnxruntime_sys as sys;
+use std::{ffi, fmt::Debug};
+use sys::OrtAllocator;
+use tracing::{debug, error};
+
+/// An input tensor.
+///
+/// This ties the lifetime of `T` to the `OrtValue`; it is used to copy an
+/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) to the runtime's memory.
+///
+/// **NOTE**: The type is not meant to be used directly, use an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
+/// instead.
+#[derive(Debug)]
+pub struct OrtInputTensor<T>
+where
+    T: Debug,
+{
+    pub(crate) c_ptr: *mut sys::OrtValue,
+    pub(crate) shape: Vec<usize>,
+    #[allow(dead_code)]
+    item: T,
+}
+
+impl<T> OrtInputTensor<T>
+where
+    T: Debug,
+{
+    /// The shape of the `OrtTensor`.
+    pub fn shape(&self) -> &[usize] {
+        &self.shape
+    }
+}
+
+impl<T, D> ConstructTensor for Array<T, D>
+where
+    T: TypeToTensorElementDataType + Debug,
+    D: Dimension,
+{
+    fn construct<'a>(
+        &'a mut self,
+        memory_info: &MemoryInfo,
+        allocator_ptr: *mut OrtAllocator,
+    ) -> Result<Box<dyn InputTensor + 'a>> {
+        // where onnxruntime will write the tensor data to
+        let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut();
+        let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr;
+
+        let sh = self.shape().to_vec();
+
+        let shape: Vec<i64> = self.shape().iter().map(|d: &usize| *d as i64).collect();
+        let shape_ptr: *const i64 = shape.as_ptr();
+        let shape_len = self.shape().len();
+
+        match T::tensor_element_data_type() {
+            TensorElementDataType::Float
+            | TensorElementDataType::Uint8
+            | TensorElementDataType::Int8
+            | TensorElementDataType::Uint16
+            | TensorElementDataType::Int16
+            | TensorElementDataType::Int32
+            | TensorElementDataType::Int64
+            | TensorElementDataType::Double
+            | TensorElementDataType::Uint32
+            | TensorElementDataType::Uint64 => {
+                let buffer_size = self.len() * std::mem::size_of::<T>();
+
+                // primitive data is already suitably laid out in memory; provide it to
+                // onnxruntime as is
+                let tensor_values_ptr: *mut std::ffi::c_void =
+                    self.as_mut_ptr().cast::<std::ffi::c_void>();
+
+                assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;
+
+                unsafe {
+                    call_ort(|ort| {
+                        ort.CreateTensorWithDataAsOrtValue.unwrap()(
+                            memory_info.ptr,
+                            tensor_values_ptr,
+                            buffer_size,
+                            shape_ptr,
+                            shape_len,
+                            T::tensor_element_data_type().into(),
+                            tensor_ptr_ptr,
+                        )
+                    })
+                }
+                .map_err(OrtError::CreateTensorWithData)?;
+                assert_not_null_pointer(tensor_ptr, "Tensor")?;
+
+                let mut is_tensor = 0;
+                let status = unsafe {
+                    ENV.get().unwrap().lock().unwrap().api().IsTensor.unwrap()(
+                        tensor_ptr,
+                        &mut is_tensor,
+                    )
+                };
+                status_to_result(status).map_err(OrtError::IsTensor)?;
+            }
+            TensorElementDataType::String => {
+                // create tensor without data -- data is filled in later
+                unsafe {
+                    call_ort(|ort| {
+                        ort.CreateTensorAsOrtValue.unwrap()(
+                            allocator_ptr,
+                            shape_ptr,
+                            shape_len,
+                            T::tensor_element_data_type().into(),
+                            tensor_ptr_ptr,
+                        )
+                    })
+                }
+                .map_err(OrtError::CreateTensor)?;
+
+                // create null-terminated copies of each string, as per `FillStringTensor` docs
+                let null_terminated_copies: Vec<ffi::CString> = self
+                    .iter()
+                    .map(|elt| {
+                        let slice = elt
+                            .try_utf8_bytes()
+                            .expect("String data type must provide utf8 bytes");
+                        ffi::CString::new(slice)
+                    })
+                    .collect::<std::result::Result<Vec<_>, _>>()
+                    .map_err(OrtError::CStringNulError)?;
+
+                let string_pointers = null_terminated_copies
+                    .iter()
+                    .map(|cstring| cstring.as_ptr())
+                    .collect::<Vec<_>>();
+
+                unsafe {
+                    call_ort(|ort| {
+                        ort.FillStringTensor.unwrap()(
+                            tensor_ptr,
+                            string_pointers.as_ptr(),
+                            string_pointers.len(),
+                        )
+                    })
+                }
+                .map_err(OrtError::FillStringTensor)?;
+            }
+        }
+
+        assert_not_null_pointer(tensor_ptr, "Tensor")?;
+
+        Ok(Box::new(OrtInputTensor {
+            c_ptr: tensor_ptr,
+            shape: sh,
+            item: self,
+        }))
+    }
+}
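One point worth spelling out about `construct` above: for primitive element types the `OrtValue` borrows the ndarray buffer directly (no copy), which is why the resulting `OrtInputTensor` holds `item` to keep the borrow alive, while string tensors are copied into runtime-owned storage via `FillStringTensor`. A hedged sketch, with `memory_info` and `allocator_ptr` obtained as in the tests further down:

```rust
// Sketch: `memory_info` and `allocator_ptr` as in the tests below.
let mut array = ndarray::arr1(&[1.0_f32, 2.0, 3.0]);
let tensor = array.construct(&memory_info, allocator_ptr).unwrap();
assert_eq!(tensor.shape(), &[3]);
// The OrtValue points into `array`'s buffer; dropping `tensor` calls
// ReleaseValue on the C side but leaves the Rust-owned buffer untouched.
drop(tensor);
```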
+
+impl<T> Drop for OrtInputTensor<T>
+where
+    T: Debug,
+{
+    #[tracing::instrument]
+    fn drop(&mut self) {
+        // We need to let the C part free
+        debug!("Dropping Tensor.");
+        if self.c_ptr.is_null() {
+            error!("Null pointer, not calling free.");
+        } else {
+            unsafe {
+                ENV.get()
+                    .unwrap()
+                    .lock()
+                    .unwrap()
+                    .api()
+                    .ReleaseValue
+                    .unwrap()(self.c_ptr)
+            }
+        }
+
+        self.c_ptr = std::ptr::null_mut();
+    }
+}
+
+impl<'a, T, D> InputTensor for OrtInputTensor<&'a mut Array<T, D>>
+where
+    T: TypeToTensorElementDataType + Debug,
+    D: Dimension,
+{
+    fn ptr(&self) -> *mut sys::OrtValue {
+        self.c_ptr
+    }
+
+    fn shape(&self) -> &[usize] {
+        &self.shape
+    }
+}
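Once constructed, inputs are handled uniformly as `InputTensor` trait objects, which is how `Session::run` consumes them internally. A small hypothetical helper, not part of the diff (the import path assumes the module layout above):

```rust
use onnxruntime::tensor::construct::InputTensor;

// Hypothetical helper: report a constructed input's rank and raw value pointer.
fn describe_input(input: &dyn InputTensor) {
    println!("rank-{} tensor at {:?}", input.shape().len(), input.ptr());
}
```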
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        environment::{tests::ONNX_RUNTIME_LIBRARY_PATH, Environment},
+        AllocatorType, LoggingLevel, MemType,
+    };
+    use ndarray::{arr0, arr1, arr2, arr3};
+    use once_cell::sync::Lazy;
+    use std::env::var;
+    use test_log::test;
+
+    static ENV: Lazy<Environment> = Lazy::new(|| {
+        let path = var(ONNX_RUNTIME_LIBRARY_PATH).ok();
+
+        let builder = Environment::builder()
+            .with_name("test")
+            .with_log_level(LoggingLevel::Warning);
+        let builder = if let Some(path) = path {
+            builder.with_library_path(path)
+        } else {
+            builder
+        };
+
+        builder.build().unwrap()
+    });
+
+    #[test]
+    fn orttensor_from_array_0d_i32() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr0::<i32>(123);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        let expected_shape: &[usize] = &[];
+        assert_eq!(tensor.shape(), expected_shape);
+    }
+
+    #[test]
+    fn orttensor_from_array_1d_i32() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr1(&[1_i32, 2, 3, 4, 5, 6]);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        let expected_shape: &[usize] = &[6];
+        assert_eq!(tensor.shape(), expected_shape);
+    }
+
+    #[test]
+    fn orttensor_from_array_2d_i32() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        assert_eq!(tensor.shape(), &[2, 6]);
+    }
+
+    #[test]
+    fn orttensor_from_array_3d_i32() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr3(&[
+            [[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
+            [[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]],
+            [[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]],
+        ]);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        assert_eq!(tensor.shape(), &[3, 2, 6]);
+    }
+
+    #[test]
+    fn orttensor_from_array_1d_string() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr1(&[
+            String::from("foo"),
+            String::from("bar"),
+            String::from("baz"),
+        ]);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        assert_eq!(tensor.shape(), &[3]);
+    }
+
+    #[test]
+    fn orttensor_from_array_3d_str() {
+        let env = &*ENV;
+
+        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default, env).unwrap();
+        let mut array = arr3(&[
+            [["1", "2", "3"], ["4", "5", "6"]],
+            [["7", "8", "9"], ["10", "11", "12"]],
+        ]);
+        let tensor = array
+            .construct(&memory_info, ort_default_allocator())
+            .unwrap();
+        assert_eq!(tensor.shape(), &[2, 2, 3]);
+    }
+
+    fn ort_default_allocator() -> *mut sys::OrtAllocator {
+        let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
+        unsafe {
+            // this default non-arena allocator doesn't need to be deallocated
+            call_ort(|ort| ort.GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr))
+        }
+        .unwrap();
+        allocator_ptr
+    }
+}
diff --git a/rust/onnxruntime/src/tensor/ort_output_tensor.rs b/rust/onnxruntime/src/tensor/ort_output_tensor.rs
new file mode 100644
index 0000000000000..5176a58c423ea
--- /dev/null
+++ b/rust/onnxruntime/src/tensor/ort_output_tensor.rs
@@ -0,0 +1,347 @@
+//! Module containing a tensor with memory owned by the ONNX Runtime
+
+use crate::{
+    environment::{_Environment, ENV},
+    error::status_to_result,
+    OrtError, Result, TypeToTensorElementDataType,
+};
+use ndarray::ArrayView;
+use onnxruntime_sys as sys;
+
+use std::{convert::TryFrom, fmt::Debug};
+use tracing::debug;
+
+/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference.
+///
+/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method.
+/// It is not meant to be created directly.
+#[derive(Debug)]
+pub struct OrtOutputTensor {
+    pub(crate) tensor_ptr: *mut sys::OrtValue,
+    pub(crate) shape: Vec<usize>,
+    env: _Environment,
+}
+
+#[derive(Debug)]
+pub(crate) struct OrtOwnedTensorExtractor {
+    pub(crate) tensor_ptr: *mut sys::OrtValue,
+    pub(crate) shape: Vec<usize>,
+    env: _Environment,
+}
+
+impl OrtOwnedTensorExtractor {
+    pub(crate) fn new(shape: Vec<usize>, env: _Environment) -> OrtOwnedTensorExtractor {
+        OrtOwnedTensorExtractor {
+            tensor_ptr: std::ptr::null_mut(),
+            shape,
+            env,
+        }
+    }
+
+    pub(crate) fn extract(self) -> Result<OrtOutputTensor> {
+        // Note: Both tensor and array will point to the same data, nothing is copied.
+        // As such, there is no need to free the pointer used to create the ArrayView.
+
+        assert_ne!(self.tensor_ptr, std::ptr::null_mut());
+
+        let mut is_tensor = 0;
+        let status =
+            unsafe { self.env.env().api().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
+        status_to_result(status).map_err(OrtError::IsTensor)?;
+        (is_tensor == 1)
+            .then_some(())
+            .ok_or(OrtError::IsTensorCheck)?;
+
+        Ok(OrtOutputTensor {
+            tensor_ptr: self.tensor_ptr,
+            shape: self.shape,
+            env: self.env,
+        })
+    }
+}
+
+impl Drop for OrtOutputTensor {
+    #[tracing::instrument]
+    fn drop(&mut self) {
+        debug!("Dropping OrtOwnedTensor.");
+        unsafe { self.env.env().api().ReleaseValue.unwrap()(self.tensor_ptr) }
+
+        self.tensor_ptr = std::ptr::null_mut();
+    }
+}
+
+/// An output tensor, pairing the raw value pointer with the item that reads from it.
+#[derive(Debug)] +pub struct WithOutputTensor<'a, T> { + #[allow(dead_code)] + pub(crate) tensor: OrtOutputTensor, + item: ArrayView<'a, T, ndarray::IxDyn>, +} + +impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> { + type Target = ArrayView<'a, T, ndarray::IxDyn>; + + fn deref(&self) -> &Self::Target { + &self.item + } +} + +impl<'a, T> TryFrom for WithOutputTensor<'a, T> +where + T: TypeToTensorElementDataType, +{ + type Error = OrtError; + + fn try_from(value: OrtOutputTensor) -> Result { + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = std::ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr.cast::<*mut std::ffi::c_void>(); + let status = unsafe { + ENV.get() + .unwrap() + .lock() + .unwrap() + .api() + .GetTensorMutableData + .unwrap()(value.tensor_ptr, output_array_ptr_ptr_void) + }; + status_to_result(status).map_err(OrtError::IsTensor)?; + assert_ne!(output_array_ptr, std::ptr::null_mut()); + + let array_view = + unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&value.shape), output_array_ptr) }; + + Ok(WithOutputTensor { + tensor: value, + item: array_view, + }) + } +} + +/// The onnxruntime Run output type. +pub enum OrtOutput<'a> { + /// Tensor of f32s + Float(WithOutputTensor<'a, f32>), + /// Tensor of f64s + Double(WithOutputTensor<'a, f64>), + /// Tensor of u8s + UInt8(WithOutputTensor<'a, u8>), + /// Tensor of u16s + UInt16(WithOutputTensor<'a, u16>), + /// Tensor of u32s + UInt32(WithOutputTensor<'a, u32>), + /// Tensor of u64s + UInt64(WithOutputTensor<'a, u64>), + /// Tensor of i8s + Int8(WithOutputTensor<'a, i8>), + /// Tensor of i16s + Int16(WithOutputTensor<'a, i16>), + /// Tensor of i32s + Int32(WithOutputTensor<'a, i32>), + /// Tensor of i64s + Int64(WithOutputTensor<'a, i64>), + /// Tensor of Strings + String(WithOutputTensor<'a, String>), +} + +impl<'a> OrtOutput<'a> { + /// Return `WithOutputTensor<'a, f32>` which derefs into an `ArrayView`. + pub fn float_array(&self) -> Option<&WithOutputTensor<'a, f32>> { + if let Self::Float(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, f64>` which derefs into an `ArrayView`. + pub fn double_array(&self) -> Option<&WithOutputTensor<'a, f64>> { + if let Self::Double(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u8>` which derefs into an `ArrayView`. + pub fn uint8_array(&self) -> Option<&WithOutputTensor<'a, u8>> { + if let Self::UInt8(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u16>` which derefs into an `ArrayView`. + pub fn uint16_array(&self) -> Option<&WithOutputTensor<'a, u16>> { + if let Self::UInt16(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u32>` which derefs into an `ArrayView`. + pub fn uint32_array(&self) -> Option<&WithOutputTensor<'a, u32>> { + if let Self::UInt32(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u64>` which derefs into an `ArrayView`. + pub fn uint64_array(&self) -> Option<&WithOutputTensor<'a, u64>> { + if let Self::UInt64(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i8>` which derefs into an `ArrayView`. 
+ pub fn int8_array(&self) -> Option<&WithOutputTensor<'a, i8>> { + if let Self::Int8(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i16>` which derefs into an `ArrayView`. + pub fn int16_array(&self) -> Option<&WithOutputTensor<'a, i16>> { + if let Self::Int16(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i32>` which derefs into an `ArrayView`. + pub fn int32_array(&self) -> Option<&WithOutputTensor<'a, i32>> { + if let Self::Int32(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i64>` which derefs into an `ArrayView`. + pub fn int64_array(&self) -> Option<&WithOutputTensor<'a, i64>> { + if let Self::Int64(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, String>` which derefs into an `ArrayView`. + pub fn string_array(&self) -> Option<&WithOutputTensor<'a, String>> { + if let Self::String(item) = self { + Some(item) + } else { + None + } + } +} + +impl<'a> TryFrom for OrtOutput<'a> { + type Error = OrtError; + + fn try_from(value: OrtOutputTensor) -> Result> { + unsafe { + let mut shape_info = std::ptr::null_mut(); + + let status = ENV + .get() + .unwrap() + .lock() + .unwrap() + .api() + .GetTensorTypeAndShape + .unwrap()(value.tensor_ptr, &mut shape_info); + + status_to_result(status).map_err(OrtError::IsTensor)?; + + assert_ne!(shape_info, std::ptr::null_mut()); + + let mut element_type = + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + let status = ENV + .get() + .unwrap() + .lock() + .unwrap() + .api() + .GetTensorElementType + .unwrap()(shape_info, &mut element_type); + + status_to_result(status).map_err(OrtError::IsTensor)?; + + ENV.get() + .unwrap() + .lock() + .unwrap() + .api() + .ReleaseTensorTypeAndShapeInfo + .unwrap()(shape_info); + + match element_type { + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => { + WithOutputTensor::try_from(value).map(OrtOutput::Float) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt8) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int8) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt16) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int16) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int32) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int64) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => { + WithOutputTensor::try_from(value).map(OrtOutput::String) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => { + WithOutputTensor::try_from(value).map(OrtOutput::Double) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => { + 
WithOutputTensor::try_from(value).map(OrtOutput::UInt32) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt64) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => { + unimplemented!() + } + } + } + } +} diff --git a/rust/onnxruntime/tests/data/mnist_5.jpg b/rust/onnxruntime/tests/data/mnist_5.jpg new file mode 100644 index 0000000000000..2216a276c4c0a Binary files /dev/null and b/rust/onnxruntime/tests/data/mnist_5.jpg differ diff --git a/rust/onnxruntime/tests/data/mushroom.png b/rust/onnxruntime/tests/data/mushroom.png new file mode 100644 index 0000000000000..2aec0b969a749 Binary files /dev/null and b/rust/onnxruntime/tests/data/mushroom.png differ diff --git a/rust/onnxruntime/tests/data/upsample.onnx b/rust/onnxruntime/tests/data/upsample.onnx new file mode 100644 index 0000000000000..43b2596edcbbd Binary files /dev/null and b/rust/onnxruntime/tests/data/upsample.onnx differ diff --git a/rust/onnxruntime/tests/integration_tests.rs b/rust/onnxruntime/tests/integration_tests.rs new file mode 100644 index 0000000000000..7843fe269e5e4 --- /dev/null +++ b/rust/onnxruntime/tests/integration_tests.rs @@ -0,0 +1,555 @@ +use onnxruntime::{error::OrtDownloadError, tensor::ndarray_tensor::NdArrayTensor}; +use std::{ + fs, + io::{self, BufRead, BufReader}, + path::Path, + sync::Arc, + time::Duration, +}; + +mod download { + use std::env::var; + + use super::*; + const RUST_ONNXRUNTIME_LIBRARY_PATH: &str = "RUST_ONNXRUNTIME_LIBRARY_PATH"; + + use image::{imageops::FilterType, ImageBuffer, Luma, Pixel, Rgb}; + use ndarray::s; + use test_log::test; + + use onnxruntime::{ + download::vision::{DomainBasedImageClassification, ImageClassification}, + environment::Environment, + GraphOptimizationLevel, LoggingLevel, + }; + + #[test] + fn squeezenet_mushroom() { + const IMAGE_TO_LOAD: &str = "mushroom.png"; + + let path = var(RUST_ONNXRUNTIME_LIBRARY_PATH).ok(); + + let environment = { + let builder = Environment::builder() + .with_name("integration_test") + .with_log_level(LoggingLevel::Warning); + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + + builder.build().unwrap() + }; + let session = environment + .new_session_builder() + .unwrap() + .with_graph_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_intra_op_num_threads(1) + .unwrap() + .with_model_downloaded(ImageClassification::SqueezeNet) + .expect("Could not download model from file"); + + let class_labels = get_imagenet_labels().unwrap(); + + let input0_shape: Vec = session.inputs[0].dimensions().map(|d| d.unwrap()).collect(); + let output0_shape: Vec = session.outputs[0] + .dimensions() + .map(|d| d.unwrap()) + .collect(); + + assert_eq!(input0_shape, [1, 3, 224, 224]); + assert_eq!(output0_shape, [1, 1000]); + + // Load image and resize to model's shape, converting to RGB format + let image_buffer: ImageBuffer, Vec> = image::open( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join(IMAGE_TO_LOAD), + ) + .unwrap() + .resize( + input0_shape[2] as u32, + input0_shape[3] as u32, + FilterType::Nearest, + ) + .to_rgb8(); + + // Python: + // # image[y, x, RGB] + // # x==0 --> left + // # y==0 --> top + + // See 
https://github.com/onnx/models/blob/main/vision/classification/imagenet_inference.ipynb + // for pre-processing image. + // WARNING: Note order of declaration of arguments: (_,c,j,i) + let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); + + // Normalize channels to mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] + let mean = [0.485, 0.456, 0.406]; + let std = [0.229, 0.224, 0.225]; + for c in 0..3 { + let mut channel_array = array.slice_mut(s![0, c, .., ..]); + channel_array -= mean[c]; + channel_array /= std[c]; + } + + // Batch of 1 + let input_tensor_values = vec![array.into()]; + + // Perform the inference + let outputs = session.run(input_tensor_values).unwrap(); + + // Downloaded model does not have a softmax as final layer; call softmax on second axis + // and iterate on resulting probabilities, creating an index to later access labels. + let output = outputs[0].float_array().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); + // Sort probabilities so highest is at beginning of vector. + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + assert_eq!( + class_labels[probabilities[0].0], "n07734744 mushroom", + "Expecting class for {} to be a mushroom", + IMAGE_TO_LOAD + ); + + assert_eq!( + probabilities[0].0, 947, + "Expecting class for {} to be a mushroom (index 947 in labels file)", + IMAGE_TO_LOAD + ); + + // for i in 0..5 { + // println!( + // "class={} ({}); probability={}", + // labels[probabilities[i].0], probabilities[i].0, probabilities[i].1 + // ); + // } + } + + #[test] + fn mnist_5() { + const IMAGE_TO_LOAD: &str = "mnist_5.jpg"; + + let path = var(RUST_ONNXRUNTIME_LIBRARY_PATH).ok(); + + let environment = { + let builder = Environment::builder() + .with_name("integration_test") + .with_log_level(LoggingLevel::Warning); + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + + builder.build().unwrap() + }; + + let session = environment + .new_session_builder() + .unwrap() + .with_graph_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_intra_op_num_threads(1) + .unwrap() + .with_model_downloaded(DomainBasedImageClassification::Mnist) + .expect("Could not download model from file"); + + let input0_shape: Vec = session.inputs[0].dimensions().map(|d| d.unwrap()).collect(); + let output0_shape: Vec = session.outputs[0] + .dimensions() + .map(|d| d.unwrap()) + .collect(); + + assert_eq!(input0_shape, [1, 1, 28, 28]); + assert_eq!(output0_shape, [1, 10]); + + // Load image and resize to model's shape, converting to RGB format + let image_buffer: ImageBuffer, Vec> = image::open( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join(IMAGE_TO_LOAD), + ) + .unwrap() + .resize( + input0_shape[2] as u32, + input0_shape[3] as u32, + FilterType::Nearest, + ) + .to_luma8(); + + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); + + // Batch of 1 + let input_tensor_values = vec![array.into()]; + + // Perform the inference + let outputs = session.run(input_tensor_values).unwrap(); + + let 
output = outputs[0].float_array().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); + + // Sort probabilities so highest is at beginning of vector. + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + assert_eq!( + probabilities[0].0, 5, + "Expecting class for {} is '5' (not {})", + IMAGE_TO_LOAD, probabilities[0].0 + ); + } + + #[test] + fn mnist_5_concurrent_session() { + const IMAGE_TO_LOAD: &str = "mnist_5.jpg"; + + let path = var(RUST_ONNXRUNTIME_LIBRARY_PATH).ok(); + + let environment = { + let builder = Environment::builder() + .with_name("integration_test") + .with_log_level(LoggingLevel::Warning); + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + + builder.build().unwrap() + }; + + let session = Arc::new( + environment + .new_session_builder() + .unwrap() + .with_graph_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_intra_op_num_threads(1) + .unwrap() + .with_model_downloaded(DomainBasedImageClassification::Mnist) + .expect("Could not download model from file"), + ); + + let children: Vec> = (0..20) + .map(move |_| { + let session = session.clone(); + std::thread::spawn(move || { + let input0_shape: Vec = + session.inputs[0].dimensions().map(|d| d.unwrap()).collect(); + let output0_shape: Vec = session.outputs[0] + .dimensions() + .map(|d| d.unwrap()) + .collect(); + + assert_eq!(input0_shape, [1, 1, 28, 28]); + assert_eq!(output0_shape, [1, 10]); + + // Load image and resize to model's shape, converting to RGB format + let image_buffer: ImageBuffer, Vec> = image::open( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join(IMAGE_TO_LOAD), + ) + .unwrap() + .resize( + input0_shape[2] as u32, + input0_shape[3] as u32, + FilterType::Nearest, + ) + .to_luma8(); + + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); + + // Batch of 1 + let input_tensor_values = vec![array.into()]; + + // Perform the inference + let outputs = session.run(input_tensor_values).unwrap(); + + let output = &outputs[0].float_array().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); + + // Sort probabilities so highest is at beginning of vector. 
+ probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + assert_eq!( + probabilities[0].0, 5, + "Expecting class for {} is '5' (not {})", + IMAGE_TO_LOAD, probabilities[0].0 + ); + }) + }) + .collect(); + + assert!(children + .into_iter() + .map(std::thread::JoinHandle::join) + .collect::, _>>() + .is_ok()); + } + + #[test] + fn mnist_5_send_session() { + const IMAGE_TO_LOAD: &str = "mnist_5.jpg"; + + let path = var(RUST_ONNXRUNTIME_LIBRARY_PATH).ok(); + + let environment = { + let builder = Environment::builder() + .with_name("integration_test") + .with_log_level(LoggingLevel::Warning); + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + + builder.build().unwrap() + }; + + let children: Vec> = (0..20) + .map(|_| { + let session = environment + .new_session_builder() + .unwrap() + .with_graph_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_intra_op_num_threads(1) + .unwrap() + .with_model_downloaded(DomainBasedImageClassification::Mnist) + .expect("Could not download model from file"); + std::thread::spawn(move || { + let input0_shape: Vec = + session.inputs[0].dimensions().map(|d| d.unwrap()).collect(); + let output0_shape: Vec = session.outputs[0] + .dimensions() + .map(|d| d.unwrap()) + .collect(); + + assert_eq!(input0_shape, [1, 1, 28, 28]); + assert_eq!(output0_shape, [1, 10]); + + // Load image and resize to model's shape, converting to RGB format + let image_buffer: ImageBuffer, Vec> = image::open( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join(IMAGE_TO_LOAD), + ) + .unwrap() + .resize( + input0_shape[2] as u32, + input0_shape[3] as u32, + FilterType::Nearest, + ) + .to_luma8(); + + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); + + // Batch of 1 + let input_tensor_values = vec![array.into()]; + + // Perform the inference + let outputs = session.run(input_tensor_values).unwrap(); + + let output = &outputs[0].float_array().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); + + // Sort probabilities so highest is at beginning of vector. + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + assert_eq!( + probabilities[0].0, 5, + "Expecting class for {} is '5' (not {})", + IMAGE_TO_LOAD, probabilities[0].0 + ); + }) + }) + .collect(); + + assert!(children + .into_iter() + .map(std::thread::JoinHandle::join) + .collect::, _>>() + .is_ok()); + } + + // This test verifies that dynamically sized inputs and outputs work. 
It loads and runs + // upsample.onnx, which was produced via: + // + // ``` + // import subprocess + // from tensorflow import keras + // + // m = keras.Sequential([ + // keras.layers.UpSampling2D(size=2) + // ]) + // m.build(input_shape=(None, None, None, 3)) + // m.summary() + // m.save('saved_model') + // + // subprocess.check_call([ + // 'python', '-m', 'tf2onnx.convert', + // '--saved-model', 'saved_model', + // '--opset', '12', + // '--output', 'upsample.onnx', + // ]) + // ``` + #[test] + fn upsample() { + const IMAGE_TO_LOAD: &str = "mushroom.png"; + + let path = var(RUST_ONNXRUNTIME_LIBRARY_PATH).ok(); + + let environment = { + let builder = Environment::builder() + .with_name("integration_test") + .with_log_level(LoggingLevel::Warning); + let builder = if let Some(path) = path { + builder.with_library_path(path) + } else { + builder + }; + + builder.build().unwrap() + }; + + let session = environment + .new_session_builder() + .unwrap() + .with_graph_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_intra_op_num_threads(1) + .unwrap() + .with_model_from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join("upsample.onnx"), + ) + .expect("Could not open model from file"); + + assert_eq!( + session.inputs[0].dimensions().collect::>(), + [None, None, None, Some(3)] + ); + assert_eq!( + session.outputs[0].dimensions().collect::>(), + [None, None, None, Some(3)] + ); + + // Load image, converting to RGB format + let image_buffer: ImageBuffer, Vec> = image::open( + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("data") + .join(IMAGE_TO_LOAD), + ) + .unwrap() + .to_rgb8(); + + let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); + + // Just one input + let input_tensor_values = vec![array.into()]; + + // Perform the inference + let outputs = session.run(input_tensor_values).unwrap(); + + assert_eq!(outputs.len(), 1); + let output = outputs[0].float_array().unwrap(); + + // The image should have doubled in size + assert_eq!(output.shape(), [1, 448, 448, 3]); + } +} + +fn get_imagenet_labels() -> Result, OrtDownloadError> { + // Download the ImageNet class labels, matching SqueezeNet's classes. 
+    let labels_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("synset.txt");
+    if !labels_path.exists() {
+        let url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt";
+        println!("Downloading {:?} to {:?}...", url, labels_path);
+        let resp = ureq::get(url)
+            .timeout(Duration::from_secs(180)) // 3 minutes
+            .call()
+            .map_err(Box::new)
+            .map_err(OrtDownloadError::UreqError)?;
+
+        assert!(resp.has("Content-Length"));
+        let len = resp
+            .header("Content-Length")
+            .and_then(|s| s.parse::<usize>().ok())
+            .unwrap();
+        println!("Downloading {} bytes...", len);
+
+        let mut reader = resp.into_reader();
+
+        let f = fs::File::create(&labels_path).unwrap();
+        let mut writer = io::BufWriter::new(f);
+
+        let bytes_io_count = io::copy(&mut reader, &mut writer).unwrap();
+
+        assert_eq!(bytes_io_count, len as u64);
+    }
+    let file = BufReader::new(fs::File::open(labels_path).unwrap());
+
+    file.lines()
+        .map(|line| line.map_err(OrtDownloadError::IoError))
+        .collect()
+}
diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml
new file mode 100644
index 0000000000000..267219dda5f37
--- /dev/null
+++ b/rust/rustfmt.toml
@@ -0,0 +1,2 @@
+format_code_in_doc_comments = true
+imports_granularity = "Crate"
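The commented-out loop in `squeezenet_mushroom` hints at how these labels pair with a softmaxed output. Spelled out as a sketch, where `output` is the `float_array()` view obtained in that test:

```rust
// Sketch: `output` is the float ArrayView from the squeezenet test above.
let labels = get_imagenet_labels().unwrap();
let mut probabilities: Vec<(usize, f32)> = output
    .softmax(ndarray::Axis(1))
    .iter()
    .copied()
    .enumerate()
    .collect();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

for (class, probability) in probabilities.iter().take(5) {
    println!("{} ({}): probability={:.4}", labels[*class], class, probability);
}
```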